Skip to content

Commit 20b9849

Browse files
Parser: fix exponential parse time on speculative prefix parsing (apache#2352)
1 parent 928783d commit 20b9849

3 files changed

Lines changed: 189 additions & 4 deletions

File tree

sqlparser_bench/benches/sqlparser_bench.rs

Lines changed: 49 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -16,7 +16,7 @@
1616
// under the License.
1717

1818
use criterion::{criterion_group, criterion_main, Criterion};
19-
use sqlparser::dialect::GenericDialect;
19+
use sqlparser::dialect::{GenericDialect, PostgreSqlDialect, SQLiteDialect};
2020
use sqlparser::keywords::Keyword;
2121
use sqlparser::parser::Parser;
2222
use sqlparser::tokenizer::{Span, Word};
@@ -200,12 +200,59 @@ fn parse_compound_keyword_chain(c: &mut Criterion) {
200200
group.finish();
201201
}
202202

203+
/// Benchmark parsing pathological `IF(<keyword-fn>(<keyword-fn>(...x` chains
204+
/// that previously caused 2^N work in `parse_prefix`. Each nested
205+
/// `current_time(` segment used to be explored twice at every level (once via
206+
/// the speculative reserved-word arm, once via the unreserved-word fallback),
207+
/// doubling work per level. Post-fix the cost is linear in chain length.
///
/// Registers three benchmarks (`chain_10`, `chain_20`, `chain_30`) in the
/// `parse_prefix_keyword_call_chain` group, all parsed with `PostgreSqlDialect`.
208+
fn parse_prefix_keyword_call_chain(c: &mut Criterion) {
209+
let mut group = c.benchmark_group("parse_prefix_keyword_call_chain");
210+
let dialect = PostgreSqlDialect {};
211+
212+
// One benchmark per chain depth `n`.
for &n in &[10usize, 20, 30] {
213+
// `if(` followed by `n` unclosed `current_time(` calls and a trailing `x`.
let sql = String::from("if(") + &"current_time(".repeat(n) + "x";
214+
215+
group.bench_function(format!("chain_{n}"), |b| {
216+
b.iter(|| {
217+
// The parse result is discarded: only the time spent parsing matters here.
// `black_box` keeps the optimizer from reasoning about the input.
let _ = Parser::parse_sql(&dialect, std::hint::black_box(&sql));
218+
});
219+
});
220+
}
221+
222+
group.finish();
223+
}
224+
225+
/// Benchmark parsing pathological `case-case-case-...c` chains that
226+
/// previously caused 2^N work in `parse_prefix`. Each `case` token used to
227+
/// trigger a speculative `parse_case_expr` that recursively descends the
228+
/// chain, but the unreserved-word fallback returns `Identifier(case)` so the
229+
/// overall `parse_prefix` succeeds and the failure cache never fires.
230+
/// Post-fix the per-arm cache short-circuits the speculative descent.
///
/// Registers three benchmarks (`chain_10`, `chain_20`, `chain_30`) in the
/// `parse_prefix_case_chain` group, all parsed with `SQLiteDialect`.
231+
fn parse_prefix_case_chain(c: &mut Criterion) {
232+
let mut group = c.benchmark_group("parse_prefix_case_chain");
233+
let dialect = SQLiteDialect {};
234+
235+
// One benchmark per chain depth `n`.
for &n in &[10usize, 20, 30] {
236+
// `n` repetitions of `case`, a tab and `-`, terminated by a single `c`.
let sql = "case\t-".repeat(n) + "c";
237+
238+
group.bench_function(format!("chain_{n}"), |b| {
239+
b.iter(|| {
240+
// The parse result is discarded: only the time spent parsing matters here.
// `black_box` keeps the optimizer from reasoning about the input.
let _ = Parser::parse_sql(&dialect, std::hint::black_box(&sql));
241+
});
242+
});
243+
}
244+
245+
group.finish();
246+
}
247+
203248
criterion_group!(
204249
benches,
205250
basic_queries,
206251
word_to_ident,
207252
parse_many_identifiers,
208253
parse_compound_chain,
209-
parse_compound_keyword_chain
254+
parse_compound_keyword_chain,
255+
parse_prefix_keyword_call_chain,
256+
parse_prefix_case_chain
210257
);
211258
criterion_main!(benches);

src/parser/mod.rs

Lines changed: 77 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -15,6 +15,7 @@
1515
#[cfg(not(feature = "std"))]
1616
use alloc::{
1717
boxed::Box,
18+
collections::BTreeMap,
1819
format,
1920
string::{String, ToString},
2021
vec,
@@ -24,6 +25,9 @@ use core::{
2425
fmt::{self, Display},
2526
str::FromStr,
2627
};
28+
#[cfg(feature = "std")]
29+
use std::collections::BTreeMap;
30+
2731
use helpers::attached_token::AttachedToken;
2832

2933
use log::debug;
@@ -359,6 +363,29 @@ pub struct Parser<'a> {
359363
options: ParserOptions,
360364
/// Ensures the stack does not overflow by limiting recursion depth.
361365
recursion_counter: RecursionCounter,
366+
/// Cached failures from `parse_prefix` calls that returned `Err`. See
367+
/// [`Parser::parse_prefix`] for the 2^N patterns this guards.
368+
failed_prefix_positions: BTreeMap<usize, ExprPrefixError>,
369+
/// Cached failures from the speculative reserved-word prefix arm. See
370+
/// [`Parser::parse_prefix`] for the 2^N patterns this guards.
371+
failed_reserved_word_prefix_positions: BTreeMap<usize, ExprPrefixError>,
372+
}
373+
374+
/// `Copy` stand-in for a [`ParserError`], stored by the `parse_prefix` failure
375+
/// memoization so that the caches hold no strings.
376+
#[derive(Debug, Clone, Copy)]
377+
enum ExprPrefixError {
378+
/// Stands in for [`ParserError::RecursionLimitExceeded`].
RecursionLimitExceeded,
379+
/// Stands in for every other [`ParserError`]; the original message is not kept.
Err,
380+
}
381+
382+
/// Collapses a [`ParserError`] into the `Copy` marker kept in the prefix
/// failure caches.
impl From<&ParserError> for ExprPrefixError {
383+
fn from(e: &ParserError) -> Self {
384+
match e {
385+
ParserError::RecursionLimitExceeded => Self::RecursionLimitExceeded,
386+
// Every other error becomes the generic marker; its message is dropped.
_ => Self::Err,
387+
}
388+
}
362389
}
363390

364391
impl<'a> Parser<'a> {
@@ -385,6 +412,8 @@ impl<'a> Parser<'a> {
385412
dialect,
386413
recursion_counter: RecursionCounter::new(DEFAULT_REMAINING_DEPTH),
387414
options: ParserOptions::new().with_trailing_commas(dialect.supports_trailing_commas()),
415+
failed_prefix_positions: BTreeMap::new(),
416+
failed_reserved_word_prefix_positions: BTreeMap::new(),
388417
}
389418
}
390419

@@ -446,6 +475,8 @@ impl<'a> Parser<'a> {
446475
pub fn with_tokens_with_locations(mut self, tokens: Vec<TokenWithSpan>) -> Self {
447476
self.tokens = tokens;
448477
self.index = 0;
478+
self.failed_prefix_positions.clear();
479+
self.failed_reserved_word_prefix_positions.clear();
449480
self
450481
}
451482

@@ -1717,6 +1748,35 @@ impl<'a> Parser<'a> {
17171748
return prefix;
17181749
}
17191750

1751+
// Memoize parse_prefix failures to break 2^N speculation when both
1752+
// prefix arms fail at every level (e.g. `IF(current_time(...x`).
1753+
// The per-arm cache in `parse_prefix_inner` complements this for
1754+
// chains where the reserved arm fails but the unreserved fallback
1755+
// succeeds (e.g. `case-case-...c`).
1756+
let start_index = self.index;
1757+
if let Some(&cached) = self.failed_prefix_positions.get(&start_index) {
1758+
return self.cached_prefix_error(cached, self.peek_token_ref());
1759+
}
1760+
let result = self.parse_prefix_inner();
1761+
if let Err(ref e) = result {
1762+
self.failed_prefix_positions.insert(start_index, e.into());
1763+
}
1764+
result
1765+
}
1766+
1767+
/// Rebuild the error for a cached prefix failure at the `found` token.
///
/// `RecursionLimitExceeded` is reproduced as
/// [`ParserError::RecursionLimitExceeded`]. The generic `Err` marker carries
/// no message, so it is re-reported through
/// `expected_ref("an expression", found)` rather than with the text of the
/// error that was originally cached.
///
/// `T` is generic because the helper is used both where an `Expr` result is
/// returned and where the speculative reserved-word result is expected.
1768+
fn cached_prefix_error<T>(
1769+
&self,
1770+
cached: ExprPrefixError,
1771+
found: &TokenWithSpan,
1772+
) -> Result<T, ParserError> {
1773+
match cached {
1774+
ExprPrefixError::RecursionLimitExceeded => Err(ParserError::RecursionLimitExceeded),
1775+
// NOTE(review): presumably `expected_ref` always yields an `Err` — confirm.
ExprPrefixError::Err => self.expected_ref("an expression", found),
1776+
}
1777+
}
1778+
1779+
fn parse_prefix_inner(&mut self) -> Result<Expr, ParserError> {
17201780
// PostgreSQL allows any string literal to be preceded by a type name, indicating that the
17211781
// string literal represents a literal of that type. Some examples:
17221782
//
@@ -1801,7 +1861,21 @@ impl<'a> Parser<'a> {
18011861
// We first try to parse the word and following tokens as a special expression, and if that fails,
18021862
// we rollback and try to parse it as an identifier.
18031863
let w = w.clone();
1804-
match self.try_parse(|parser| parser.parse_expr_prefix_by_reserved_word(&w, span)) {
1864+
// Memoize failed speculative reserved-word parses. When
1865+
// the reserved arm (CASE, CURRENT_TIME, etc.) does
1866+
// exponential work but the unreserved fallback ultimately
1867+
// succeeds, the overall `parse_prefix` returns `Ok` and the
1868+
// outer cache never fires. Chains like `case-case-...c`
1869+
// need this per-arm cache to break the doubling.
1870+
let try_parse_result = if let Some(&cached) = self
1871+
.failed_reserved_word_prefix_positions
1872+
.get(&next_token_index)
1873+
{
1874+
self.cached_prefix_error(cached, self.get_current_token())
1875+
} else {
1876+
self.try_parse(|parser| parser.parse_expr_prefix_by_reserved_word(&w, span))
1877+
};
1878+
match try_parse_result {
18051879
// This word indicated an expression prefix and parsing was successful
18061880
Ok(Some(expr)) => Ok(expr),
18071881

@@ -1815,6 +1889,8 @@ impl<'a> Parser<'a> {
18151889
// we rollback and return the parsing error we got from trying to parse a
18161890
// special expression (to maintain backwards compatibility of parsing errors).
18171891
Err(e) => {
1892+
self.failed_reserved_word_prefix_positions
1893+
.insert(next_token_index, (&e).into());
18181894
if !self.dialect.is_reserved_for_identifier(w.keyword) {
18191895
if let Ok(Some(expr)) = self.maybe_parse(|parser| {
18201896
parser.parse_expr_prefix_by_unreserved_word(&w, span)

tests/sqlparser_common.rs

Lines changed: 63 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -15572,7 +15572,10 @@ fn parse_create_table_select() {
1557215572

1557315573
#[test]
1557415574
fn test_reserved_keywords_for_identifiers() {
15575-
let dialects = all_dialects_where(|d| d.is_reserved_for_identifier(Keyword::INTERVAL));
15575+
let dialects = all_dialects_where(|d| {
15576+
d.is_reserved_for_identifier(Keyword::INTERVAL)
15577+
&& !d.supports_named_fn_args_with_expr_name()
15578+
});
1557615579
// Dialects that reserve the word INTERVAL will not allow it as an unquoted identifier
1557715580
let sql = "SELECT MAX(interval) FROM tbl";
1557815581
assert_eq!(
@@ -15582,6 +15585,19 @@ fn test_reserved_keywords_for_identifiers() {
1558215585
))
1558315586
);
1558415587

15588+
// Dialects with expression-named function arguments parse the argument
15589+
// expression twice, so the second attempt reports the memoized failure
15590+
// at the start of the expression
15591+
let dialects = all_dialects_where(|d| {
15592+
d.is_reserved_for_identifier(Keyword::INTERVAL) && d.supports_named_fn_args_with_expr_name()
15593+
});
15594+
assert_eq!(
15595+
dialects.parse_sql_statements(sql),
15596+
Err(ParserError::ParserError(
15597+
"Expected: an expression, found: interval".to_string()
15598+
))
15599+
);
15600+
1558515601
// Dialects that do not reserve the word INTERVAL will allow it
1558615602
let dialects = all_dialects_where(|d| !d.is_reserved_for_identifier(Keyword::INTERVAL));
1558715603
let sql = "SELECT MAX(interval) FROM tbl";
@@ -19035,3 +19051,49 @@ fn parse_compound_keyword_chain_no_exponential_blowup() {
1903519051
rx.recv_timeout(Duration::from_secs(5))
1903619052
.expect("parser should handle this quickly, not loop exponentially");
1903719053
}
19054+
19055+
/// Regression test for the 2^N parse-time blowup in `parse_prefix` on inputs
19056+
/// like `IF(current_time(current_time(...x`. Each nested `current_time(` used
19057+
/// to be explored twice at every level (once via the speculative reserved-word
19058+
/// arm, once via the unreserved-word fallback), doubling work per level.
19059+
/// Post-fix the failing parse short-circuits via the position-keyed cache.
///
/// The parse runs on a worker thread and the test fails if no completion
/// signal arrives within 5 seconds. The parse result itself is not inspected.
19060+
#[test]
19061+
fn parse_prefix_keyword_call_chain_no_exponential_blowup() {
19062+
use std::sync::mpsc;
19063+
use std::thread;
19064+
use std::time::Duration;
19065+
19066+
// `if(` followed by 30 unclosed `current_time(` calls and a trailing `x`.
let sql = String::from("if(") + &"current_time(".repeat(30) + "x";
19067+
19068+
let (tx, rx) = mpsc::channel();
19069+
// Parse off the test thread so that a slow parse shows up as a timeout
// below instead of hanging the test. The worker is never joined.
thread::spawn(move || {
19070+
let _ = Parser::parse_sql(&PostgreSqlDialect {}, &sql);
19071+
// Signal completion; a send error (receiver already gone) is ignored.
let _ = tx.send(());
19072+
});
19073+
19074+
// `expect` panics, failing the test, when nothing is received in time.
rx.recv_timeout(Duration::from_secs(5))
19075+
.expect("parser should reject this quickly, not loop exponentially");
19076+
}
19077+
19078+
/// Regression test for the 2^N parse-time blowup in `parse_prefix` on inputs
19079+
/// like `case-case-case-...c`. Each `case` token triggers a speculative
19080+
/// `parse_case_expr` that fails, but the unreserved-word fallback returns
19081+
/// `Identifier(case)`, so the outer failure cache never fires. Post-fix the
19082+
/// per-arm cache short-circuits the speculative descent.
///
/// The parse runs on a worker thread and the test fails if no completion
/// signal arrives within 5 seconds. The parse result itself is not inspected.
19083+
#[test]
19084+
fn parse_prefix_case_chain_no_exponential_blowup() {
19085+
use std::sync::mpsc;
19086+
use std::thread;
19087+
use std::time::Duration;
19088+
19089+
// 30 repetitions of `case`, a tab and `-`, terminated by a single `c`.
let sql = "case\t-".repeat(30) + "c";
19090+
19091+
let (tx, rx) = mpsc::channel();
19092+
// Parse off the test thread so that a slow parse shows up as a timeout
// below instead of hanging the test. The worker is never joined.
thread::spawn(move || {
19093+
let _ = Parser::parse_sql(&SQLiteDialect {}, &sql);
19094+
// Signal completion; a send error (receiver already gone) is ignored.
let _ = tx.send(());
19095+
});
19096+
19097+
// `expect` panics, failing the test, when nothing is received in time.
rx.recv_timeout(Duration::from_secs(5))
19098+
.expect("parser should reject this quickly, not loop exponentially");
19099+
}

0 commit comments

Comments
 (0)