Skip to content

Commit b5e93fc

Browse files
authored
perf: optimize PikeVM by matching multiple literal bytes in tight loop when possible (VirusTotal#678)
This PR optimizes PikeVM matching for contiguous runs of literal bytes, speeding up regex matching for patterns that contain literal sequences. It introduces Instr::Bytes(LiteralBytesIter), which groups contiguous literal bytes so they can be matched in a single fast loop, bypassing the VM's thread scheduling, state transitions, and epsilon-closure overhead for every byte. LiteralBytesIter is a lazy iterator that decodes bytecode literals on the fly, resolving escaped OPCODE_PREFIX (0xAA) sequences and stopping at the first unescaped control opcode. The optimization is enabled only when there is exactly one active thread in the VM (self.threads.len() == 1): because a run consumes multiple input bytes in a single step, applying it with several active threads would desynchronize the other threads, which may be matching at different offsets or branches.
1 parent 7330d40 commit b5e93fc

4 files changed

Lines changed: 285 additions & 13 deletions

File tree

lib/src/re/thompson/compiler.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1723,6 +1723,10 @@ impl Display for InstrSeq {
17231723
Instr::Byte(byte) => {
17241724
writeln!(f, "{addr:05x}: LIT {byte:#04x}")?;
17251725
}
1726+
Instr::Bytes(iter) => {
1727+
let bytes: Vec<u8> = iter.clone().collect();
1728+
writeln!(f, "{addr:05x}: BYTES {:?}", bytes)?;
1729+
}
17261730
Instr::MaskedByte { byte, mask } => {
17271731
writeln!(
17281732
f,

lib/src/re/thompson/instr.rs

Lines changed: 156 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -171,6 +171,9 @@ pub enum Instr<'a> {
171171
/// Matches a specific byte.
172172
Byte(u8),
173173

174+
/// Matches a sequence of bytes.
175+
Bytes(LiteralBytesIter<'a>),
176+
174177
/// Matches a case-insensitive character. The value of `u8` is in the
175178
/// range a-z.
176179
CaseInsensitiveChar(u8),
@@ -320,7 +323,31 @@ impl<'a> InstrParser<'a> {
320323
}
321324

322325
#[inline(always)]
323-
pub fn decode_instr(code: &[u8]) -> (Instr<'_>, usize) {
326+
fn is_bytes_run(code: &[u8]) -> bool {
327+
match code {
328+
[
329+
OPCODE_PREFIX,
330+
OPCODE_PREFIX,
331+
OPCODE_PREFIX,
332+
OPCODE_PREFIX,
333+
..,
334+
] => true,
335+
[OPCODE_PREFIX, OPCODE_PREFIX, x, ..] if *x != OPCODE_PREFIX => {
336+
true
337+
}
338+
[x, OPCODE_PREFIX, OPCODE_PREFIX, ..] if *x != OPCODE_PREFIX => {
339+
true
340+
}
341+
[x, y, ..] if *x != OPCODE_PREFIX && *y != OPCODE_PREFIX => true,
342+
_ => false,
343+
}
344+
}
345+
346+
#[inline(always)]
347+
pub fn decode_instr(
348+
code: &[u8],
349+
decode_literal_runs: bool,
350+
) -> (Instr<'_>, usize) {
324351
match code[..] {
325352
[OPCODE_PREFIX, Instr::ANY_BYTE, ..] => (Instr::AnyByte, 2),
326353
[OPCODE_PREFIX, Instr::MASKED_BYTE, byte, mask, ..] => {
@@ -425,10 +452,15 @@ impl<'a> InstrParser<'a> {
425452
[OPCODE_PREFIX, Instr::WORD_START, ..] => (Instr::WordStart, 2),
426453
[OPCODE_PREFIX, Instr::WORD_END, ..] => (Instr::WordEnd, 2),
427454
[OPCODE_PREFIX, Instr::MATCH, ..] => (Instr::Match, 2),
428-
[OPCODE_PREFIX, OPCODE_PREFIX, ..] => {
429-
(Instr::Byte(OPCODE_PREFIX), 2)
455+
[_, ..] => {
456+
if decode_literal_runs && Self::is_bytes_run(code) {
457+
(Instr::Bytes(LiteralBytesIter::new(code)), 0)
458+
} else if code[0] == OPCODE_PREFIX {
459+
(Instr::Byte(OPCODE_PREFIX), 2)
460+
} else {
461+
(Instr::Byte(code[0]), 1)
462+
}
430463
}
431-
[b, ..] => (Instr::Byte(b), 1),
432464
_ => unreachable!(),
433465
}
434466
}
@@ -473,7 +505,7 @@ impl<'a> Iterator for InstrParser<'a> {
473505
if self.code.is_empty() {
474506
return None;
475507
}
476-
let (instr, size) = InstrParser::decode_instr(self.code);
508+
let (instr, size) = InstrParser::decode_instr(self.code, false);
477509
let addr = self.addr;
478510
self.addr += size;
479511
self.code = &self.code[size..];
@@ -595,3 +627,122 @@ pub fn literal_code_length(literal: &[u8]) -> usize {
595627
}
596628
length
597629
}
630+
631+
#[derive(Clone, Debug)]
632+
pub struct LiteralBytesIter<'a> {
633+
code: &'a [u8],
634+
offset: usize,
635+
}
636+
637+
impl<'a> LiteralBytesIter<'a> {
638+
#[inline(always)]
639+
pub fn new(code: &'a [u8]) -> Self {
640+
Self { code, offset: 0 }
641+
}
642+
643+
#[inline(always)]
644+
pub fn consumed(&self) -> usize {
645+
self.offset
646+
}
647+
}
648+
649+
impl<'a> Iterator for LiteralBytesIter<'a> {
650+
type Item = u8;
651+
652+
#[inline(always)]
653+
fn next(&mut self) -> Option<Self::Item> {
654+
if self.offset >= self.code.len() {
655+
return None;
656+
}
657+
match self.code[self.offset..] {
658+
[OPCODE_PREFIX, OPCODE_PREFIX, ..] => {
659+
self.offset += 2;
660+
Some(OPCODE_PREFIX)
661+
}
662+
[OPCODE_PREFIX, ..] => None,
663+
[byte, ..] => {
664+
self.offset += 1;
665+
Some(byte)
666+
}
667+
_ => None,
668+
}
669+
}
670+
}
671+
672+
#[cfg(test)]
673+
mod tests {
674+
use super::*;
675+
676+
#[test]
677+
fn test_decode_instr() {
678+
// Run of standard literal bytes
679+
let code = [0x01, 0x02, OPCODE_PREFIX, Instr::SPLIT_B];
680+
let (instr, size) = InstrParser::decode_instr(&code, true);
681+
assert_eq!(size, 0);
682+
if let Instr::Bytes(mut iter) = instr {
683+
assert_eq!(iter.next(), Some(0x01));
684+
assert_eq!(iter.next(), Some(0x02));
685+
assert_eq!(iter.next(), None);
686+
assert_eq!(iter.consumed(), 2);
687+
} else {
688+
panic!("Expected Instr::Bytes");
689+
}
690+
691+
let (instr, size) = InstrParser::decode_instr(&code, false);
692+
assert_eq!(size, 1);
693+
assert!(matches!(instr, Instr::Byte(0x01)));
694+
695+
// Run of escaped OPCODE_PREFIX bytes
696+
let code =
697+
[OPCODE_PREFIX, OPCODE_PREFIX, OPCODE_PREFIX, OPCODE_PREFIX];
698+
let (instr, size) = InstrParser::decode_instr(&code, true);
699+
assert_eq!(size, 0);
700+
if let Instr::Bytes(mut iter) = instr {
701+
assert_eq!(iter.next(), Some(OPCODE_PREFIX));
702+
assert_eq!(iter.next(), Some(OPCODE_PREFIX));
703+
assert_eq!(iter.next(), None);
704+
assert_eq!(iter.consumed(), 4);
705+
} else {
706+
panic!("Expected Instr::Bytes");
707+
}
708+
709+
// Run with escaped OPCODE_PREFIX and standard literal byte
710+
let code = [OPCODE_PREFIX, OPCODE_PREFIX, 0x01];
711+
let (instr, size) = InstrParser::decode_instr(&code, true);
712+
assert_eq!(size, 0);
713+
if let Instr::Bytes(mut iter) = instr {
714+
assert_eq!(iter.next(), Some(OPCODE_PREFIX));
715+
assert_eq!(iter.next(), Some(0x01));
716+
assert_eq!(iter.next(), None);
717+
assert_eq!(iter.consumed(), 3);
718+
} else {
719+
panic!("Expected Instr::Bytes");
720+
}
721+
722+
// Run with standard literal byte and escaped OPCODE_PREFIX
723+
let code = [0x01, OPCODE_PREFIX, OPCODE_PREFIX];
724+
let (instr, size) = InstrParser::decode_instr(&code, true);
725+
assert_eq!(size, 0);
726+
if let Instr::Bytes(mut iter) = instr {
727+
assert_eq!(iter.next(), Some(0x01));
728+
assert_eq!(iter.next(), Some(OPCODE_PREFIX));
729+
assert_eq!(iter.next(), None);
730+
assert_eq!(iter.consumed(), 3);
731+
} else {
732+
panic!("Expected Instr::Bytes");
733+
}
734+
735+
// Single literal byte (not a run of >= 2)
736+
let code = [0x01, OPCODE_PREFIX, Instr::SPLIT_B];
737+
let (instr, size) = InstrParser::decode_instr(&code, true);
738+
assert_eq!(size, 1);
739+
assert!(matches!(instr, Instr::Byte(0x01)));
740+
741+
// Single escaped OPCODE_PREFIX (not a run of >= 2)
742+
let code =
743+
[OPCODE_PREFIX, OPCODE_PREFIX, OPCODE_PREFIX, Instr::SPLIT_B];
744+
let (instr, size) = InstrParser::decode_instr(&code, true);
745+
assert_eq!(size, 2);
746+
assert!(matches!(instr, Instr::Byte(OPCODE_PREFIX)));
747+
}
748+
}

lib/src/re/thompson/pikevm.rs

Lines changed: 79 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -162,7 +162,6 @@ impl<'r> PikeVM<'r> {
162162
F: Iterator<Item = &'a u8>,
163163
B: Iterator<Item = &'a u8>,
164164
{
165-
let step = 1;
166165
let mut current_pos = 0;
167166
let mut curr_byte = fwd_input.next();
168167

@@ -181,18 +180,87 @@ impl<'r> PikeVM<'r> {
181180
);
182181

183182
while !self.threads.is_empty() {
184-
let next_byte = fwd_input.next();
183+
let mut next_byte = fwd_input.next();
184+
// When there is only a single active thread in the VM (that is,
185+
// `self.threads.len() == 1`), we can optimize execution by
186+
// decoding a contiguous run of literal bytes (`Instr::Bytes`) and
187+
// matching them in a fast loop.
188+
//
189+
// This is only safe when there is one thread active. If there were
190+
// multiple concurrent threads, matching a byte run would consume
191+
// multiple bytes from the input on the fly, which would
192+
// desynchronize and bypass other threads matching at different
193+
// positions or branches. Setting `decode_literal_runs` to false is
194+
// always safe; it simply disables the optimization.
195+
let decode_literal_runs = self.threads.len() == 1;
185196

186197
for (ip, rep_count) in self.threads.iter() {
187-
let (instr, instr_size) = InstrParser::decode_instr(unsafe {
188-
self.code.get_unchecked(*ip..)
189-
});
198+
let (instr, mut instr_size) = InstrParser::decode_instr(
199+
unsafe { self.code.get_unchecked(*ip..) },
200+
decode_literal_runs,
201+
);
190202

191203
let is_match = match instr {
192204
Instr::AnyByte => curr_byte.is_some(),
193205
Instr::Byte(byte) => {
194206
matches!(curr_byte, Some(b) if *b == byte)
195207
}
208+
// `Instr::Bytes` matches a sequence of literal bytes in
209+
// a single VM step. This bypasses standard VM thread
210+
// scheduling and state updates, matching the sequence
211+
// directly against the input stream in a fast loop. The decoder
212+
// emits this variant only when `decode_literal_runs` is true.
213+
Instr::Bytes(mut lit_bytes) => {
214+
let is_match = 'is_match: {
215+
let first = match lit_bytes.next() {
216+
Some(first) => first,
217+
None => break 'is_match false,
218+
};
219+
220+
if !matches!(curr_byte, Some(b) if *b == first) {
221+
break 'is_match false;
222+
}
223+
224+
let second = match lit_bytes.next() {
225+
Some(second) => second,
226+
None => break 'is_match true,
227+
};
228+
229+
if !matches!(next_byte, Some(b) if *b == second) {
230+
break 'is_match false;
231+
}
232+
233+
curr_byte = next_byte;
234+
current_pos += 1;
235+
236+
// Match the remaining literal bytes in the
237+
// sequence by consuming bytes from the input
238+
// stream.
239+
for expected_byte in lit_bytes.by_ref() {
240+
curr_byte = fwd_input.next();
241+
match curr_byte {
242+
Some(curr_byte) => {
243+
current_pos += 1;
244+
if *curr_byte != expected_byte {
245+
break 'is_match false;
246+
}
247+
}
248+
None => break 'is_match false,
249+
}
250+
}
251+
252+
next_byte = fwd_input.next();
253+
break 'is_match true;
254+
};
255+
256+
// Since the instruction size is not known
257+
// statically when decoding `Instr::Bytes`, we
258+
// retrieve the number of consumed bytecode bytes
259+
// from the iterator to advance the instruction
260+
// pointer correctly.
261+
instr_size = lit_bytes.consumed();
262+
is_match
263+
}
196264
Instr::MaskedByte { byte, mask } => {
197265
matches!(curr_byte, Some(b) if *b & mask == byte)
198266
}
@@ -226,7 +294,7 @@ impl<'r> PikeVM<'r> {
226294
}
227295

228296
curr_byte = next_byte;
229-
current_pos += step;
297+
current_pos += 1;
230298

231299
mem::swap(&mut self.threads, &mut self.next_threads);
232300
self.next_threads.clear();
@@ -328,11 +396,14 @@ pub(crate) fn epsilon_closure<C: CodeLoc>(
328396
};
329397

330398
while let Some((ip, mut rep_count)) = state.threads.pop() {
331-
let (instr, instr_size) =
332-
InstrParser::decode_instr(unsafe { code.get_unchecked(ip..) });
399+
let (instr, instr_size) = InstrParser::decode_instr(
400+
unsafe { code.get_unchecked(ip..) },
401+
false,
402+
);
333403
match instr {
334404
Instr::AnyByte
335405
| Instr::Byte(_)
406+
| Instr::Bytes(_)
336407
| Instr::MaskedByte { .. }
337408
| Instr::CaseInsensitiveChar(_)
338409
| Instr::ClassBitmap(_)

lib/src/scanner/tests.rs

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1053,3 +1053,49 @@ fn fast_scan_mode() {
10531053
test_count.patterns().filter(|p| p.identifier() == "$c");
10541054
assert_eq!(patterns_c.next().unwrap().matches().len(), 2);
10551055
}
1056+
1057+
#[test]
1058+
fn test_pikevm_literal_run_optimization() {
1059+
let rules = crate::compile(
1060+
r#"
1061+
rule test_opt {
1062+
strings:
1063+
$a = /abcdefg.*hijk.*lmno/
1064+
condition:
1065+
$a
1066+
}
1067+
"#,
1068+
)
1069+
.unwrap();
1070+
1071+
let mut scanner = Scanner::new(&rules);
1072+
1073+
let results = scanner.scan(b"abcdefg_hijk_lmno").unwrap();
1074+
assert_eq!(results.matching_rules().count(), 1);
1075+
1076+
let results = scanner.scan(b"abcdefg_hijk_lmn").unwrap();
1077+
assert_eq!(results.matching_rules().count(), 0);
1078+
1079+
let results = scanner.scan(b"abcdef_hijk_lmno").unwrap();
1080+
assert_eq!(results.matching_rules().count(), 0);
1081+
}
1082+
1083+
#[test]
1084+
fn test_slow_rule_hang() {
1085+
let rules = crate::compile(
1086+
r#"
1087+
rule test {
1088+
strings:
1089+
$zero_padding = /\x00{860,}/
1090+
condition:
1091+
$zero_padding
1092+
}
1093+
"#,
1094+
)
1095+
.unwrap();
1096+
1097+
let mut scanner = Scanner::new(&rules);
1098+
let data = vec![0u8; 2000];
1099+
let results = scanner.scan(&data).unwrap();
1100+
assert_eq!(results.matching_rules().count(), 1);
1101+
}

0 commit comments

Comments
 (0)