Skip to content
Merged
6 changes: 3 additions & 3 deletions crates/pgt_statement_splitter/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,10 @@
pub mod diagnostics;
mod parser;

use parser::{Parse, Parser, source};
use parser::{Parser, ParserResult, source};
use pgt_lexer::diagnostics::ScanError;

pub fn split(sql: &str) -> Result<Parse, Vec<ScanError>> {
pub fn split(sql: &str) -> Result<ParserResult, Vec<ScanError>> {
let tokens = pgt_lexer::lex(sql)?;

let mut parser = Parser::new(tokens);
Expand All @@ -28,7 +28,7 @@ mod tests {

struct Tester {
input: String,
parse: Parse,
parse: ParserResult,
}

impl From<&str> for Tester {
Expand Down
235 changes: 123 additions & 112 deletions crates/pgt_statement_splitter/src/parser.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,24 +13,24 @@ use crate::diagnostics::SplitDiagnostic;
/// Main parser that exposes the `cstree` api, and collects errors and statements
/// It is modelled after a Pratt Parser. For a gentle introduction to Pratt Parsing, see https://matklad.github.io/2020/04/13/simple-but-powerful-pratt-parsing.html
pub struct Parser {
/// The ranges of the statements
ranges: Vec<(usize, usize)>,
/// The statement ranges are defined by the indices of the start/end tokens
stmt_ranges: Vec<(usize, usize)>,

/// The syntax errors accumulated during parsing
errors: Vec<SplitDiagnostic>,
/// The start of the current statement, if any

current_stmt_start: Option<usize>,
/// The tokens to parse
pub tokens: Vec<Token>,

tokens: Vec<Token>,

eof_token: Token,

next_pos: usize,
current_pos: usize,
}

/// Result of Building
#[derive(Debug)]
pub struct Parse {
/// The ranges of the errors
pub struct ParserResult {
/// The ranges of the parsed statements
pub ranges: Vec<TextRange>,
/// The syntax errors accumulated during parsing
pub errors: Vec<SplitDiagnostic>,
Expand All @@ -41,40 +41,34 @@ impl Parser {
let eof_token = Token::eof(usize::from(
tokens
.last()
.map(|t| t.span.start())
.map(|t| t.span.end())
.unwrap_or(TextSize::from(0)),
));

// next_pos should be the initialised with the first valid token already
let mut next_pos = 0;
loop {
let token = tokens.get(next_pos).unwrap_or(&eof_token);

if is_irrelevant_token(token) {
next_pos += 1;
} else {
break;
}
// Place `current_pos` on the first relevant token
let mut current_pos = 0;
while is_irrelevant_token(tokens.get(current_pos).unwrap_or(&eof_token)) {
current_pos += 1;
}

Self {
ranges: Vec::new(),
stmt_ranges: Vec::new(),
eof_token,
errors: Vec::new(),
current_stmt_start: None,
tokens,
next_pos,
current_pos,
}
}

pub fn finish(self) -> Parse {
Parse {
pub fn finish(self) -> ParserResult {
ParserResult {
ranges: self
.ranges
.stmt_ranges
.iter()
.map(|(start, end)| {
let from = self.tokens.get(*start);
let to = self.tokens.get(*end).unwrap_or(&self.eof_token);
.map(|(start_token_pos, end_token_pos)| {
let from = self.tokens.get(*start_token_pos);
let to = self.tokens.get(*end_token_pos).unwrap_or(&self.eof_token);

TextRange::new(from.unwrap().span.start(), to.span.end())
})
Expand All @@ -83,124 +77,87 @@ impl Parser {
}
}

/// Start statement
pub fn start_stmt(&mut self) {
assert!(
self.current_stmt_start.is_none(),
"cannot start statement within statement at {:?}",
self.tokens.get(self.current_stmt_start.unwrap())
);
self.current_stmt_start = Some(self.next_pos);
self.current_stmt_start = Some(self.current_pos);
}

/// Close statement
pub fn close_stmt(&mut self) {
assert!(self.next_pos > 0);

// go back the positions until we find the first relevant token
let mut end_token_pos = self.next_pos - 1;
loop {
let token = self.tokens.get(end_token_pos);
assert!(
self.current_stmt_start.is_some(),
"Must start statement before closing it."
);

if end_token_pos == 0 || token.is_none() {
break;
}
let start_token_pos = self.current_stmt_start.unwrap();

if !is_irrelevant_token(token.unwrap()) {
break;
}
assert!(
self.current_pos > start_token_pos,
"Must close the statement on a token that's later than the start token."
);

end_token_pos -= 1;
}
let (end_token_pos, _) = self.find_last_relevant().unwrap();

self.ranges.push((
self.current_stmt_start.expect("Expected active statement"),
end_token_pos,
));
self.stmt_ranges.push((start_token_pos, end_token_pos));

self.current_stmt_start = None;
}

fn advance(&mut self) -> &Token {
let mut first_relevant_token = None;
loop {
let token = self.tokens.get(self.next_pos).unwrap_or(&self.eof_token);

// we need to continue with next_pos until the next relevant token after we already
// found the first one
if !is_irrelevant_token(token) {
if let Some(t) = first_relevant_token {
return t;
}
first_relevant_token = Some(token);
}

self.next_pos += 1;
}
}

fn peek(&self) -> &Token {
match self.tokens.get(self.next_pos) {
/// Returns the token at `current_pos` without advancing.
/// Once `current_pos` has run past the end of `tokens`, the
/// synthetic EOF token is returned instead.
fn current(&self) -> &Token {
match self.tokens.get(self.current_pos) {
Some(token) => token,
None => &self.eof_token,
}
}

/// Look ahead to the next relevant token
/// Returns `None` if we are already at the last relevant token
fn look_ahead(&self) -> Option<&Token> {
// we need to look ahead to the next relevant token
let mut look_ahead_pos = self.next_pos + 1;
loop {
let token = self.tokens.get(look_ahead_pos)?;

if !is_irrelevant_token(token) {
return Some(token);
}
/// Moves `current_pos` forward to the next relevant (non-whitespace)
/// token and returns it. If no relevant token remains, `current_pos`
/// is set to `tokens.len()` and the EOF token is returned.
fn advance(&mut self) -> &Token {
// can't reuse any `find_next_relevant` logic because of Mr. Borrow Checker
let (pos, token) = self
.tokens
.iter()
.enumerate()
.skip(self.current_pos + 1)
.find(|(_, t)| is_relevant(t))
// no relevant token left: park on the out-of-range EOF position
.unwrap_or((self.tokens.len(), &self.eof_token));

self.current_pos = pos;
token
}

look_ahead_pos += 1;
}
/// Peeks at the next relevant (non-whitespace) token after `current_pos`
/// without advancing. Returns `None` if no relevant token follows.
fn look_ahead(&self) -> Option<&Token> {
self.tokens
.iter()
.skip(self.current_pos + 1)
.find(|t| is_relevant(t))
}

/// Returns `None` if there are no previous relevant tokens
fn look_back(&self) -> Option<&Token> {
// we need to look back to the last relevant token
let mut look_back_pos = self.next_pos - 1;
loop {
let token = self.tokens.get(look_back_pos);

if look_back_pos == 0 || token.is_none() {
return None;
}

if !is_irrelevant_token(token.unwrap()) {
return token;
}

look_back_pos -= 1;
}
self.find_last_relevant().map(|it| it.1)
}

/// checks if the current token is of `kind` and advances if true
/// returns true if the current token is of `kind`
pub fn eat(&mut self, kind: SyntaxKind) -> bool {
if self.peek().kind == kind {
/// Will advance if the `kind` matches the current token.
/// Otherwise, will add a diagnostic to the internal `errors`.
pub fn expect(&mut self, kind: SyntaxKind) {
if self.current().kind == kind {
self.advance();
true
} else {
false
self.errors.push(SplitDiagnostic::new(
format!("Expected {:#?}", kind),
self.current().span,
));
}
}

pub fn expect(&mut self, kind: SyntaxKind) {
if self.eat(kind) {
return;
}

self.errors.push(SplitDiagnostic::new(
format!("Expected {:#?}", kind),
self.peek().span,
));
/// Finds the last relevant (non-whitespace) token *strictly before*
/// `current_pos` — `take(self.current_pos)` excludes the current token.
/// Returns its index and the token, or `None` if no earlier relevant
/// token exists.
fn find_last_relevant(&self) -> Option<(usize, &Token)> {
self.tokens
.iter()
.enumerate()
.take(self.current_pos)
.rfind(|(_, t)| is_relevant(t))
}
}

Expand All @@ -219,3 +176,57 @@ fn is_irrelevant_token(t: &Token) -> bool {
WHITESPACE_TOKENS.contains(&t.kind)
&& (t.kind != SyntaxKind::Newline || t.text.chars().count() == 1)
}

/// Inverse of `is_irrelevant_token`: true for tokens the parser
/// should act on (everything except skippable whitespace).
fn is_relevant(t: &Token) -> bool {
!is_irrelevant_token(t)
}

#[cfg(test)]
mod tests {
use pgt_lexer::SyntaxKind;

use crate::parser::Parser;

#[test]
fn advance_works_as_expected() {
// Multi-line statement: the surrounding whitespace/newline tokens are
// irrelevant and must be skipped transparently by `advance`.
let sql = r#"
create table users (
id serial primary key,
name text,
email text
);
"#;
let tokens = pgt_lexer::lex(sql).unwrap();
let total_num_tokens = tokens.len();

let mut parser = Parser::new(tokens);

// Each entry is the (token kind, current_pos) the parser is expected
// to sit on after the previous `advance` call; the position gaps are
// the skipped whitespace tokens.
let expected = vec![
(SyntaxKind::Create, 2),
(SyntaxKind::Table, 4),
(SyntaxKind::Ident, 6),
(SyntaxKind::Ascii40, 8),
(SyntaxKind::Ident, 11),
(SyntaxKind::Ident, 13),
(SyntaxKind::Primary, 15),
(SyntaxKind::Key, 17),
(SyntaxKind::Ascii44, 18),
(SyntaxKind::NameP, 21),
(SyntaxKind::TextP, 23),
(SyntaxKind::Ascii44, 24),
(SyntaxKind::Ident, 27),
(SyntaxKind::TextP, 29),
(SyntaxKind::Ascii41, 32),
(SyntaxKind::Ascii59, 33),
];

for (kind, pos) in expected {
assert_eq!(parser.current().kind, kind);
assert_eq!(parser.current_pos, pos);
parser.advance();
}

// After the last relevant token, `advance` parks on the EOF sentinel
// with `current_pos` one past the end of the token vector.
assert_eq!(parser.current().kind, SyntaxKind::Eof);
assert_eq!(parser.current_pos, total_num_tokens);
}
}
10 changes: 5 additions & 5 deletions crates/pgt_statement_splitter/src/parser/common.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ use super::{

pub fn source(p: &mut Parser) {
loop {
match p.peek() {
match p.current() {
Token {
kind: SyntaxKind::Eof,
..
Expand All @@ -33,7 +33,7 @@ pub fn source(p: &mut Parser) {

pub(crate) fn statement(p: &mut Parser) {
p.start_stmt();
match p.peek().kind {
match p.current().kind {
SyntaxKind::With => {
cte(p);
}
Expand Down Expand Up @@ -68,7 +68,7 @@ pub(crate) fn parenthesis(p: &mut Parser) {
let mut depth = 1;

loop {
match p.peek().kind {
match p.current().kind {
SyntaxKind::Ascii40 => {
p.advance();
depth += 1;
Expand All @@ -91,7 +91,7 @@ pub(crate) fn case(p: &mut Parser) {
p.expect(SyntaxKind::Case);

loop {
match p.peek().kind {
match p.current().kind {
SyntaxKind::EndP => {
p.advance();
break;
Expand All @@ -105,7 +105,7 @@ pub(crate) fn case(p: &mut Parser) {

pub(crate) fn unknown(p: &mut Parser, exclude: &[SyntaxKind]) {
loop {
match p.peek() {
match p.current() {
Token {
kind: SyntaxKind::Ascii59,
..
Expand Down
4 changes: 3 additions & 1 deletion crates/pgt_statement_splitter/src/parser/dml.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,9 @@ pub(crate) fn cte(p: &mut Parser) {
p.expect(SyntaxKind::As);
parenthesis(p);

if !p.eat(SyntaxKind::Ascii44) {
if p.current().kind == SyntaxKind::Ascii44 {
p.advance();
} else {
break;
}
}
Expand Down
Loading