Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ mod test {
fn rewrite_prefix_string_not_applicable() {
assert!(code_action_not_applicable(
rewrite_as_dollar_quoted_string,
"select b'foo$0';"
"select b'010$0';"
));
}
}
237 changes: 208 additions & 29 deletions crates/squawk_parser/src/lexed_str.rs
Original file line number Diff line number Diff line change
Expand Up @@ -170,20 +170,10 @@ impl<'a> Converter<'a> {

squawk_lexer::TokenKind::Whitespace => SyntaxKind::WHITESPACE,
squawk_lexer::TokenKind::Ident => {
// TODO: check for max identifier length
//
// see: https://www.postgresql.org/docs/16/sql-syntax-lexical.html#SQL-SYNTAX-IDENTIFIERS
// The system uses no more than NAMEDATALEN-1 bytes of an
// identifier; longer names can be written in commands, but
// they will be truncated. By default, NAMEDATALEN is 64 so
// the maximum identifier length is 63 bytes. If this limit
// is problematic, it can be raised by changing the
// NAMEDATALEN constant in src/include/pg_config_manual.h.
// see: https://github.com/postgres/postgres/blob/e032e4c7ddd0e1f7865b246ec18944365d4f8614/src/include/pg_config_manual.h#L29
SyntaxKind::from_keyword(token_text).unwrap_or(SyntaxKind::IDENT)
}
squawk_lexer::TokenKind::Literal { kind, .. } => {
self.extend_literal(token_text.len(), kind);
self.extend_literal(token_text, kind);
return;
}
squawk_lexer::TokenKind::Semi => SyntaxKind::SEMICOLON,
Expand Down Expand Up @@ -233,13 +223,13 @@ impl<'a> Converter<'a> {
self.push(syntax_kind, token_text.len(), err);
}

fn extend_literal(&mut self, len: usize, kind: &squawk_lexer::LiteralKind) {
let mut err = "";
fn extend_literal(&mut self, token_text: &str, kind: &squawk_lexer::LiteralKind) {
let mut err: Option<String> = None;

let syntax_kind = match *kind {
squawk_lexer::LiteralKind::Int { empty_int, base: _ } => {
if empty_int {
err = "Missing digits after the integer base prefix";
err = Some("Missing digits after the integer base prefix".into());
}
SyntaxKind::INT_NUMBER
}
Expand All @@ -248,56 +238,245 @@ impl<'a> Converter<'a> {
base: _,
} => {
if empty_exponent {
err = "Missing digits after the exponent symbol";
err = Some("Missing digits after the exponent symbol".into());
}
SyntaxKind::FLOAT_NUMBER
}
squawk_lexer::LiteralKind::Str { terminated } => {
if !terminated {
err = "Missing trailing `'` symbol to terminate the string literal";
err =
Some("Missing trailing `'` symbol to terminate the string literal".into());
}
// TODO: rust analzyer checks for un-escaped strings, we should too
SyntaxKind::STRING
}
squawk_lexer::LiteralKind::ByteStr { terminated } => {
if !terminated {
err = "Missing trailing `'` symbol to terminate the hex bit string literal";
err = Some(
"Missing trailing `'` symbol to terminate the hex bit string literal"
.into(),
);
} else {
let inside = &token_text[2..token_text.len() - 1];
if let Some(c) = inside.chars().find(|c| !c.is_ascii_hexdigit()) {
err = Some(format!("\"{c}\" is not a valid hexadecimal digit"));
}
}
// TODO: rust analzyer checks for un-escaped strings, we should too
SyntaxKind::BYTE_STRING
}
squawk_lexer::LiteralKind::BitStr { terminated } => {
if !terminated {
err = "Missing trailing `\'` symbol to terminate the bit string literal";
err = Some(
"Missing trailing `'` symbol to terminate the bit string literal".into(),
);
} else {
let inside = &token_text[2..token_text.len() - 1];
if let Some(c) = inside.chars().find(|&c| c != '0' && c != '1') {
err = Some(format!("\"{c}\" is not a valid binary digit"));
}
}
// TODO: rust analzyer checks for un-escaped strings, we should too
SyntaxKind::BIT_STRING
}
squawk_lexer::LiteralKind::DollarQuotedString { terminated } => {
if !terminated {
// TODO: we could be fancier and say the ending string we're looking for
err = "Unterminated dollar quoted string literal";
err = Some("Unterminated dollar quoted string literal".into());
}
// TODO: rust analzyer checks for un-escaped strings, we should too
SyntaxKind::DOLLAR_QUOTED_STRING
}
squawk_lexer::LiteralKind::UnicodeEscStr { terminated } => {
if !terminated {
err = "Missing trailing `'` symbol to terminate the unicode escape string literal";
err = Some(
"Missing trailing `'` symbol to terminate the unicode escape string literal"
.into(),
);
}
// TODO: rust analzyer checks for un-escaped strings, we should too
// validated in squawk_syntax
SyntaxKind::UNICODE_ESC_STRING
}
squawk_lexer::LiteralKind::EscStr { terminated } => {
if !terminated {
err = "Missing trailing `\'` symbol to terminate the escape string literal";
err = Some(
"Missing trailing `'` symbol to terminate the escape string literal".into(),
);
} else {
err = validate_escape_string_unicode_escapes(token_text);
}
// TODO: rust analzyer checks for un-escaped strings, we should too
SyntaxKind::ESC_STRING
}
};

let err = if err.is_empty() { None } else { Some(err) };
self.push(syntax_kind, len, err);
self.push(syntax_kind, token_text.len(), err.as_deref());
}
}

/// Checks the `\uXXXX` and `\UXXXXXXXX` escapes inside an escape string token.
///
/// `token_text` is the full token text: the first two bytes (the prefix
/// letter and the opening quote, e.g. `E'`) and the final byte (the closing
/// quote) are skipped, and only what lies between them is inspected.
///
/// Returns `None` when every `\u` is followed by 4 ASCII hex digits and every
/// `\U` by 8. Otherwise returns the message for the first escape that falls
/// short. Only the digit count is checked, not the resulting code point.
///
/// Any other backslash escape is skipped together with the character after
/// it, so an escaped backslash followed by `u` (`\\u`) is not mistaken for a
/// unicode escape.
///
/// A token too short to hold a prefix, two quotes and a body, or one where
/// byte 2 is not a character boundary, has nothing to validate and yields
/// `None` rather than panicking on the slice.
fn validate_escape_string_unicode_escapes(token_text: &str) -> Option<String> {
    // Checked slicing: `get` returns `None` for out-of-range, inverted or
    // non-boundary ranges, where plain indexing would panic.
    let body_end = token_text.len().checked_sub(1)?;
    let body = token_text.get(2..body_end)?;
    let mut chars = body.chars();

    while let Some(c) = chars.next() {
        if c != '\\' {
            continue;
        }

        // Consume the character after the backslash unconditionally so that
        // it can never start an escape of its own.
        let (required, example) = match chars.next() {
            Some('u') => (4, r"\uXXXX"),
            Some('U') => (8, r"\UXXXXXXXX"),
            _ => continue,
        };

        for _ in 0..required {
            // Running out of input counts as a missing digit.
            if !chars.next().is_some_and(|c| c.is_ascii_hexdigit()) {
                return Some(format!(
                    "Unicode escape requires {required} hex digits: {example}"
                ));
            }
        }
    }

    None
}

#[cfg(test)]
mod tests {
use annotate_snippets::{AnnotationKind, Level, Renderer, Snippet, renderer::DecorStyle};
use insta::assert_snapshot;

use super::LexedStr;

fn lex(text: &str) -> String {
let lexed = LexedStr::new(text);
let renderer = Renderer::plain().decor_style(DecorStyle::Unicode);
let mut res = String::new();

for (token, msg) in lexed.errors() {
let group = Level::ERROR.primary_title(msg).element(
Snippet::source(text)
.fold(true)
.annotation(AnnotationKind::Primary.span(lexed.text_range(token))),
);
res.push_str(&renderer.render(&[group]).to_string());
res.push('\n');
}

res
}

// An integer base prefix with no digits after it (`0x`) must be reported,
// with a two-character span underlined in the rendered snapshot.
#[test]
fn empty_int_error() {
assert_snapshot!(lex("select 0x;"), @"
error: Missing digits after the integer base prefix
╭▸
1 │ select 0x;
╰╴ ━━
");
}

// An exponent marker with no digits after it (`1e`) must be reported.
#[test]
fn empty_exponent_error() {
assert_snapshot!(lex("select 1e;"), @"
error: Missing digits after the exponent symbol
╭▸
1 │ select 1e;
╰╴ ━━
");
}

// A plain string literal with no closing quote must be reported. The
// underline is as wide as `'hello;`, i.e. the rest of the input.
#[test]
fn unterminated_string_error() {
assert_snapshot!(lex("select 'hello;"), @"
error: Missing trailing `'` symbol to terminate the string literal
╭▸
1 │ select 'hello;
╰╴ ━━━━━━━
");
}

// A terminated hex string containing a non-hex character (`Z`) must name the
// offending character in the error. The snapshot uses a `r#"…"#` literal
// because the message itself contains double quotes.
#[test]
fn hex_invalid_digit() {
assert_snapshot!(lex("select X'1FZ';"), @r#"
error: "Z" is not a valid hexadecimal digit
╭▸
1 │ select X'1FZ';
╰╴ ━━━━━━
"#);
}

// A hex string (`X'...`) with no closing quote must be reported as
// unterminated.
#[test]
fn unterminated_hex_bit_string_error() {
assert_snapshot!(lex("select X'1F;"), @"
error: Missing trailing `'` symbol to terminate the hex bit string literal
╭▸
1 │ select X'1F;
╰╴ ━━━━━
");
}

// A bit string (`B'...`) with no closing quote must be reported as
// unterminated.
#[test]
fn unterminated_bit_string_error() {
assert_snapshot!(lex("select B'101;"), @"
error: Missing trailing `'` symbol to terminate the bit string literal
╭▸
1 │ select B'101;
╰╴ ━━━━━━
");
}

// A terminated bit string may only contain `0` and `1`; here the space inside
// the quotes is reported as the invalid digit. The snapshot uses a `r#"…"#`
// literal because the message itself contains double quotes.
#[test]
fn invalid_binary_digit_error() {
assert_snapshot!(lex("select b'0 ';"), @r#"
error: " " is not a valid binary digit
╭▸
1 │ select b'0 ';
╰╴ ━━━━━
"#);
}

// A dollar-quoted string whose closing `$tag$` never appears must be
// reported as unterminated.
#[test]
fn unterminated_dollar_quoted_string_error() {
assert_snapshot!(lex("select $tag$hello;"), @"
error: Unterminated dollar quoted string literal
╭▸
1 │ select $tag$hello;
╰╴ ━━━━━━━━━━━
");
}

// A unicode escape string (`U&'...`) with no closing quote must be reported
// as unterminated.
#[test]
fn unterminated_unicode_escape_string_error() {
assert_snapshot!(lex("select U&'hello;"), @"
error: Missing trailing `'` symbol to terminate the unicode escape string literal
╭▸
1 │ select U&'hello;
╰╴ ━━━━━━━━━
");
}

// An escape string (`E'...`) with no closing quote must be reported as
// unterminated.
#[test]
fn unterminated_escape_string_error() {
assert_snapshot!(lex("select E'hello;"), @"
error: Missing trailing `'` symbol to terminate the escape string literal
╭▸
1 │ select E'hello;
╰╴ ━━━━━━━━
");
}

// `\u` inside an escape string needs four hex digits; `\u00` supplies only
// two, so the error must be reported. Raw string literals keep the
// backslashes in the input and the snapshot literal.
#[test]
fn invalid_unicode_escape_4_digits_error() {
assert_snapshot!(lex(r"select E'\u00';"), @r"
error: Unicode escape requires 4 hex digits: \uXXXX
╭▸
1 │ select E'\u00';
╰╴ ━━━━━━━
");
}

// `\U` inside an escape string needs eight hex digits; `\UFFFF` supplies only
// four, so the error must be reported. Raw string literals keep the
// backslashes in the input and the snapshot literal.
#[test]
fn invalid_unicode_escape_8_digits_error() {
assert_snapshot!(lex(r"select E'\UFFFF';"), @r"
error: Unicode escape requires 8 hex digits: \UXXXXXXXX
╭▸
1 │ select E'\UFFFF';
╰╴ ━━━━━━━━━
");
}
}
1 change: 1 addition & 0 deletions crates/squawk_syntax/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@ impl<T> Parse<T> {
vec![]
};
validation::validate(&self.syntax_node(), &mut errors);
errors.sort_by_key(|error| error.range().start());
errors
}
}
Expand Down
Loading
Loading