Skip to content

Commit fbbffea

Browse files
author
Roman Borschel
committed
Add support for MSSQL IF/ELSE statements.
These are syntactically quite different from the already supported IF ... THEN ... ELSEIF ... END IF statements. Hence IfStatement is now an enum with two variants and statement parsing is overridden for the MSSQL dialect in order to parse IF statements differently for MSSQL. Thereby fix spans for if/case AST nodes by including start/end tokens, if present.
1 parent 776b10a commit fbbffea

File tree

6 files changed

+508
-134
lines changed

6 files changed

+508
-134
lines changed

src/ast/mod.rs

Lines changed: 140 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,8 @@ use serde::{Deserialize, Serialize};
3737
#[cfg(feature = "visitor")]
3838
use sqlparser_derive::{Visit, VisitMut};
3939

40-
use crate::tokenizer::Span;
40+
use crate::keywords::Keyword;
41+
use crate::tokenizer::{Span, Token, TokenWithSpan};
4142

4243
pub use self::data_type::{
4344
ArrayElemTypeDef, BinaryLength, CharLengthUnits, CharacterLength, DataType, EnumMember,
@@ -2118,20 +2119,23 @@ pub enum Password {
21182119
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
21192120
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
21202121
pub struct CaseStatement {
2122+
/// The `CASE` token that starts the statement.
2123+
pub case_token: TokenWithSpan,
21212124
pub match_expr: Option<Expr>,
21222125
pub when_blocks: Vec<ConditionalStatements>,
2123-
pub else_block: Option<Vec<Statement>>,
2124-
/// TRUE if the statement ends with `END CASE` (vs `END`).
2125-
pub has_end_case: bool,
2126+
pub else_block: Option<ConditionalStatements>,
2127+
/// The last token of the statement (`END` or `CASE`).
2128+
pub end_case_token: TokenWithSpan,
21262129
}
21272130

21282131
impl fmt::Display for CaseStatement {
21292132
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
21302133
let CaseStatement {
2134+
case_token: _,
21312135
match_expr,
21322136
when_blocks,
21332137
else_block,
2134-
has_end_case,
2138+
end_case_token,
21352139
} = self;
21362140

21372141
write!(f, "CASE")?;
@@ -2145,116 +2149,189 @@ impl fmt::Display for CaseStatement {
21452149
}
21462150

21472151
if let Some(else_block) = else_block {
2148-
write!(f, " ELSE ")?;
2149-
format_statement_list(f, else_block)?;
2152+
write!(f, " {else_block}")?;
21502153
}
21512154

21522155
write!(f, " END")?;
2153-
if *has_end_case {
2154-
write!(f, " CASE")?;
2156+
2157+
if let Token::Word(w) = &end_case_token.token {
2158+
if w.keyword == Keyword::CASE {
2159+
write!(f, " CASE")?;
2160+
}
21552161
}
21562162

21572163
Ok(())
21582164
}
21592165
}
21602166

21612167
/// An `IF` statement.
2162-
///
2163-
/// Examples:
2164-
/// ```sql
2165-
/// IF TRUE THEN
2166-
/// SELECT 1;
2167-
/// SELECT 2;
2168-
/// ELSEIF TRUE THEN
2169-
/// SELECT 3;
2170-
/// ELSE
2171-
/// SELECT 4;
2172-
/// END IF
2173-
/// ```
2174-
///
2175-
/// [BigQuery](https://cloud.google.com/bigquery/docs/reference/standard-sql/procedural-language#if)
2176-
/// [Snowflake](https://docs.snowflake.com/en/sql-reference/snowflake-scripting/if)
21772168
#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
21782169
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
21792170
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
2180-
pub struct IfStatement {
2181-
pub if_block: ConditionalStatements,
2182-
pub elseif_blocks: Vec<ConditionalStatements>,
2183-
pub else_block: Option<Vec<Statement>>,
2171+
pub enum IfStatement {
2172+
/// An `IF ... THEN [ELSE[IF] ...] END IF` statement.
2173+
///
2174+
/// Example:
2175+
/// ```sql
2176+
/// IF TRUE THEN
2177+
/// SELECT 1;
2178+
/// SELECT 2;
2179+
/// ELSEIF TRUE THEN
2180+
/// SELECT 3;
2181+
/// ELSE
2182+
/// SELECT 4;
2183+
/// END IF
2184+
/// ```
2185+
///
2186+
/// [BigQuery](https://cloud.google.com/bigquery/docs/reference/standard-sql/procedural-language#if)
2187+
/// [Snowflake](https://docs.snowflake.com/en/sql-reference/snowflake-scripting/if)
2188+
IfThenElseEnd {
2189+
/// The `IF` token that starts the statement.
2190+
if_token: TokenWithSpan,
2191+
if_block: ConditionalStatements,
2192+
elseif_blocks: Vec<ConditionalStatements>,
2193+
else_block: Option<ConditionalStatements>,
2194+
/// The `IF` token that ends the statement.
2195+
end_if_token: TokenWithSpan,
2196+
},
2197+
/// An MSSQL `IF ... ELSE ...` statement.
2198+
///
2199+
/// Example:
2200+
/// ```sql
2201+
/// IF 1=1 SELECT 1 ELSE SELECT 2
2202+
/// ```
2203+
///
2204+
/// [MSSQL](https://learn.microsoft.com/en-us/sql/t-sql/language-elements/if-else-transact-sql?view=sql-server-ver16)
2205+
MsSqlIfElse {
2206+
if_token: TokenWithSpan,
2207+
condition: Expr,
2208+
if_statements: MsSqlIfStatements,
2209+
else_statements: Option<MsSqlIfStatements>,
2210+
},
21842211
}
21852212

21862213
impl fmt::Display for IfStatement {
21872214
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
2188-
let IfStatement {
2189-
if_block,
2190-
elseif_blocks,
2191-
else_block,
2192-
} = self;
2215+
match self {
2216+
IfStatement::IfThenElseEnd {
2217+
if_token: _,
2218+
if_block,
2219+
elseif_blocks,
2220+
else_block,
2221+
end_if_token: _,
2222+
} => {
2223+
write!(f, "{if_block}")?;
21932224

2194-
write!(f, "{if_block}")?;
2225+
if !elseif_blocks.is_empty() {
2226+
write!(f, " {}", display_separated(elseif_blocks, " "))?;
2227+
}
21952228

2196-
if !elseif_blocks.is_empty() {
2197-
write!(f, " {}", display_separated(elseif_blocks, " "))?;
2198-
}
2229+
if let Some(else_block) = else_block {
2230+
write!(f, " {else_block}")?;
2231+
}
21992232

2200-
if let Some(else_block) = else_block {
2201-
write!(f, " ELSE ")?;
2202-
format_statement_list(f, else_block)?;
2203-
}
2233+
write!(f, " END IF")?;
2234+
2235+
Ok(())
2236+
}
2237+
IfStatement::MsSqlIfElse {
2238+
if_token: _,
2239+
condition,
2240+
if_statements,
2241+
else_statements,
2242+
} => {
2243+
write!(f, "IF {condition} {if_statements}")?;
22042244

2205-
write!(f, " END IF")?;
2245+
if let Some(els) = else_statements {
2246+
write!(f, " ELSE {els}")?;
2247+
}
22062248

2207-
Ok(())
2249+
Ok(())
2250+
}
2251+
}
22082252
}
22092253
}
22102254

2211-
/// Represents a type of [ConditionalStatements]
2255+
/// (MSSQL) Either a single [Statement] or a block of statements
2256+
/// enclosed in `BEGIN` and `END`.
22122257
#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
22132258
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
22142259
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
2215-
pub enum ConditionalStatementKind {
2216-
/// `WHEN <condition> THEN <statements>`
2217-
When,
2218-
/// `IF <condition> THEN <statements>`
2219-
If,
2220-
/// `ELSEIF <condition> THEN <statements>`
2221-
ElseIf,
2260+
pub enum MsSqlIfStatements {
2261+
/// A single statement.
2262+
Single(Box<Statement>),
2263+
/// ```sql
2264+
/// A logical block of statements.
2265+
///
2266+
/// BEGIN
2267+
/// <statement>;
2268+
/// <statement>;
2269+
/// ...
2270+
/// END
2271+
/// ```
2272+
Block {
2273+
begin_token: TokenWithSpan,
2274+
statements: Vec<Statement>,
2275+
end_token: TokenWithSpan,
2276+
},
2277+
}
2278+
2279+
impl fmt::Display for MsSqlIfStatements {
2280+
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
2281+
match self {
2282+
MsSqlIfStatements::Single(stmt) => stmt.fmt(f),
2283+
MsSqlIfStatements::Block { statements, .. } => {
2284+
write!(f, "BEGIN ")?;
2285+
format_statement_list(f, statements)?;
2286+
write!(f, " END")
2287+
}
2288+
}
2289+
}
22222290
}
22232291

22242292
/// A block within a [Statement::Case] or [Statement::If]-like statement
22252293
///
2226-
/// Examples:
2294+
/// Example 1:
22272295
/// ```sql
22282296
/// WHEN EXISTS(SELECT 1) THEN SELECT 1;
2297+
/// ```
22292298
///
2299+
/// Example 2:
2300+
/// ```sql
22302301
/// IF TRUE THEN SELECT 1; SELECT 2;
22312302
/// ```
2303+
///
2304+
/// Example 3:
2305+
/// ```sql
2306+
/// ELSE SELECT 1; SELECT 2;
2307+
/// ```
22322308
#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
22332309
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
22342310
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
22352311
pub struct ConditionalStatements {
2236-
/// The condition expression.
2237-
pub condition: Expr,
2312+
/// The start token of the conditional (`WHEN`, `IF`, `ELSEIF` or `ELSE`).
2313+
pub start_token: TokenWithSpan,
2314+
/// The condition expression. `None` for `ELSE` statements.
2315+
pub condition: Option<Expr>,
22382316
/// Statement list of the `THEN` clause.
22392317
pub statements: Vec<Statement>,
2240-
pub kind: ConditionalStatementKind,
22412318
}
22422319

22432320
impl fmt::Display for ConditionalStatements {
22442321
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
22452322
let ConditionalStatements {
2246-
condition: expr,
2323+
start_token,
2324+
condition,
22472325
statements,
2248-
kind,
22492326
} = self;
22502327

2251-
let kind = match kind {
2252-
ConditionalStatementKind::When => "WHEN",
2253-
ConditionalStatementKind::If => "IF",
2254-
ConditionalStatementKind::ElseIf => "ELSEIF",
2255-
};
2328+
let keyword = &start_token.token;
22562329

2257-
write!(f, "{kind} {expr} THEN")?;
2330+
if let Some(expr) = condition {
2331+
write!(f, "{keyword} {expr} THEN")?;
2332+
} else {
2333+
write!(f, "{keyword}")?;
2334+
}
22582335

22592336
if !statements.is_empty() {
22602337
write!(f, " ")?;

src/ast/spans.rs

Lines changed: 47 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -30,13 +30,14 @@ use super::{
3030
FunctionArgumentClause, FunctionArgumentList, FunctionArguments, GroupByExpr, HavingBound,
3131
IfStatement, IlikeSelectItem, Insert, Interpolate, InterpolateExpr, Join, JoinConstraint,
3232
JoinOperator, JsonPath, JsonPathElem, LateralView, LimitClause, MatchRecognizePattern, Measure,
33-
NamedWindowDefinition, ObjectName, ObjectNamePart, Offset, OnConflict, OnConflictAction,
34-
OnInsert, OrderBy, OrderByExpr, OrderByKind, Partition, PivotValueSource, ProjectionSelect,
35-
Query, RaiseStatement, RaiseStatementValue, ReferentialAction, RenameSelectItem,
36-
ReplaceSelectElement, ReplaceSelectItem, Select, SelectInto, SelectItem, SetExpr, SqlOption,
37-
Statement, Subscript, SymbolDefinition, TableAlias, TableAliasColumnDef, TableConstraint,
38-
TableFactor, TableObject, TableOptionsClustered, TableWithJoins, UpdateTableFromKind, Use,
39-
Value, Values, ViewColumnDef, WildcardAdditionalOptions, With, WithFill,
33+
MsSqlIfStatements, NamedWindowDefinition, ObjectName, ObjectNamePart, Offset, OnConflict,
34+
OnConflictAction, OnInsert, OrderBy, OrderByExpr, OrderByKind, Partition, PivotValueSource,
35+
ProjectionSelect, Query, RaiseStatement, RaiseStatementValue, ReferentialAction,
36+
RenameSelectItem, ReplaceSelectElement, ReplaceSelectItem, Select, SelectInto, SelectItem,
37+
SetExpr, SqlOption, Statement, Subscript, SymbolDefinition, TableAlias, TableAliasColumnDef,
38+
TableConstraint, TableFactor, TableObject, TableOptionsClustered, TableWithJoins,
39+
UpdateTableFromKind, Use, Value, Values, ViewColumnDef, WildcardAdditionalOptions, With,
40+
WithFill,
4041
};
4142

4243
/// Given an iterator of spans, return the [Span::union] of all spans.
@@ -739,47 +740,63 @@ impl Spanned for CreateIndex {
739740
impl Spanned for CaseStatement {
740741
fn span(&self) -> Span {
741742
let CaseStatement {
742-
match_expr,
743-
when_blocks,
744-
else_block,
745-
has_end_case: _,
743+
case_token,
744+
end_case_token,
745+
..
746746
} = self;
747747

748-
union_spans(
749-
match_expr
750-
.iter()
751-
.map(|e| e.span())
752-
.chain(when_blocks.iter().map(|b| b.span()))
753-
.chain(else_block.iter().flat_map(|e| e.iter().map(|s| s.span()))),
754-
)
748+
union_spans([case_token.span, end_case_token.span].into_iter())
755749
}
756750
}
757751

758752
impl Spanned for IfStatement {
759753
fn span(&self) -> Span {
760-
let IfStatement {
761-
if_block,
762-
elseif_blocks,
763-
else_block,
764-
} = self;
754+
match self {
755+
IfStatement::IfThenElseEnd {
756+
if_token,
757+
end_if_token,
758+
..
759+
} => union_spans([if_token.span, end_if_token.span].into_iter()),
760+
IfStatement::MsSqlIfElse {
761+
if_token,
762+
if_statements,
763+
else_statements,
764+
..
765+
} => union_spans(
766+
[if_token.span, if_statements.span()]
767+
.into_iter()
768+
.chain(else_statements.as_ref().into_iter().map(|s| s.span())),
769+
),
770+
}
771+
}
772+
}
765773

766-
union_spans(
767-
iter::once(if_block.span())
768-
.chain(elseif_blocks.iter().map(|b| b.span()))
769-
.chain(else_block.iter().flat_map(|e| e.iter().map(|s| s.span()))),
770-
)
774+
impl Spanned for MsSqlIfStatements {
775+
fn span(&self) -> Span {
776+
match self {
777+
MsSqlIfStatements::Single(s) => s.span(),
778+
MsSqlIfStatements::Block {
779+
begin_token,
780+
end_token,
781+
..
782+
} => union_spans([begin_token.span, end_token.span].into_iter()),
783+
}
771784
}
772785
}
773786

774787
impl Spanned for ConditionalStatements {
775788
fn span(&self) -> Span {
776789
let ConditionalStatements {
790+
start_token,
777791
condition,
778792
statements,
779-
kind: _,
780793
} = self;
781794

782-
union_spans(iter::once(condition.span()).chain(statements.iter().map(|s| s.span())))
795+
union_spans(
796+
iter::once(start_token.span)
797+
.chain(condition.as_ref().map(|c| c.span()).into_iter())
798+
.chain(statements.iter().map(|s| s.span())),
799+
)
783800
}
784801
}
785802

0 commit comments

Comments
 (0)