Skip to content

Commit b7529a2

Browse files
Snowflake: parse FOR-over-cursor loop (FOR row IN cursor DO)
1 parent db7c83f commit b7529a2

6 files changed

Lines changed: 114 additions & 33 deletions

File tree

src/ast/mod.rs

Lines changed: 38 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -2790,32 +2790,57 @@ impl fmt::Display for WhileStatement {
27902790
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
27912791
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
27922792
pub struct ForStatement {
2793-
/// Loop counter variable.
2793+
/// Loop counter (range form) or row variable (cursor form).
27942794
pub var: Ident,
2795-
/// `true` when the loop iterates from `end` down to `start`.
2796-
pub reverse: bool,
2797-
/// Inclusive lower bound (or upper bound when `reverse`).
2798-
pub start: Expr,
2799-
/// Inclusive upper bound (or lower bound when `reverse`).
2800-
pub end: Expr,
2795+
/// What the loop iterates over.
2796+
pub iteration: ForIterationSource,
28012797
/// Loop body statements.
28022798
pub body: ConditionalStatements,
28032799
}
28042800

2801+
/// The thing a [`ForStatement`] iterates over.
2802+
#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
2803+
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
2804+
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
2805+
pub enum ForIterationSource {
2806+
/// Numeric range: `[REVERSE] <start> TO <end>`.
2807+
Range {
2808+
/// `true` when the loop iterates from `end` down to `start`.
2809+
reverse: bool,
2810+
/// Inclusive lower bound (or upper bound when `reverse`).
2811+
start: Expr,
2812+
/// Inclusive upper bound (or lower bound when `reverse`).
2813+
end: Expr,
2814+
},
2815+
/// Iterate over a cursor or query result set: `FOR row IN cursor DO`.
2816+
///
2817+
/// [Snowflake](https://docs.snowflake.com/en/developer-guide/snowflake-scripting/cursors)
2818+
Cursor(Expr),
2819+
}
2820+
28052821
impl fmt::Display for ForStatement {
28062822
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
28072823
let ForStatement {
28082824
var,
2809-
reverse,
2810-
start,
2811-
end,
2825+
iteration,
28122826
body,
28132827
} = self;
28142828
write!(f, "FOR {var} IN ")?;
2815-
if *reverse {
2816-
write!(f, "REVERSE ")?;
2829+
match iteration {
2830+
ForIterationSource::Range {
2831+
reverse,
2832+
start,
2833+
end,
2834+
} => {
2835+
if *reverse {
2836+
write!(f, "REVERSE ")?;
2837+
}
2838+
write!(f, "{start} TO {end} DO")?;
2839+
}
2840+
ForIterationSource::Cursor(source) => {
2841+
write!(f, "{source} DO")?;
2842+
}
28172843
}
2818-
write!(f, "{start} TO {end} DO")?;
28192844
if !body.statements().is_empty() {
28202845
write!(f, " {body}")?;
28212846
}

src/ast/spans.rs

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,8 @@ use super::{
3434
ColumnOption, ColumnOptionDef, ConditionalStatementBlock, ConditionalStatements,
3535
ConflictTarget, ConnectByKind, ConstraintCharacteristics, CopySource, CreateIndex, CreateTable,
3636
CreateTableOptions, Cte, Delete, DoUpdate, ExceptSelectItem, ExcludeSelectItem, Expr,
37-
ExprWithAlias, Fetch, ForStatement, ForValues, FromTable, Function, FunctionArg,
37+
ExprWithAlias, Fetch, ForIterationSource, ForStatement, ForValues, FromTable, Function,
38+
FunctionArg,
3839
FunctionArgExpr, FunctionArgumentClause, FunctionArgumentList, FunctionArguments, GroupByExpr,
3940
HavingBound, IfStatement, IlikeSelectItem, IndexColumn, Insert, Interpolate, InterpolateExpr,
4041
Join, JoinConstraint, JoinOperator, JsonPath, JsonPathElem, LateralView, LimitClause,
@@ -798,16 +799,18 @@ impl Spanned for ForStatement {
798799
fn span(&self) -> Span {
799800
let ForStatement {
800801
var,
801-
reverse: _,
802-
start,
803-
end,
802+
iteration,
804803
body,
805804
} = self;
806-
union_spans(
807-
[var.span, start.span(), end.span(), body.span()]
808-
.into_iter()
809-
.filter(|s| s != &Span::empty()),
810-
)
805+
let mut spans = vec![var.span, body.span()];
806+
match iteration {
807+
ForIterationSource::Range { start, end, .. } => {
808+
spans.push(start.span());
809+
spans.push(end.span());
810+
}
811+
ForIterationSource::Cursor(source) => spans.push(source.span()),
812+
}
813+
union_spans(spans.into_iter().filter(|s| s != &Span::empty()))
811814
}
812815
}
813816

src/dialect/mod.rs

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -425,6 +425,13 @@ pub trait Dialect: Debug + Any {
425425
false
426426
}
427427

428+
/// Returns true if the dialect supports the cursor-iteration `FOR` loop
429+
/// (`FOR <row> IN <cursor> DO ... END FOR`) in addition to the numeric
430+
/// range form.
431+
fn supports_for_loop_over_cursor(&self) -> bool {
432+
false
433+
}
434+
428435
/// Returns true if the dialect supports the MATCH_RECOGNIZE operation.
429436
fn supports_match_recognize(&self) -> bool {
430437
false

src/dialect/snowflake.rs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -190,6 +190,11 @@ impl Dialect for SnowflakeDialect {
190190
true
191191
}
192192

193+
/// See <https://docs.snowflake.com/en/developer-guide/snowflake-scripting/cursors>
194+
fn supports_for_loop_over_cursor(&self) -> bool {
195+
true
196+
}
197+
193198
fn supports_match_recognize(&self) -> bool {
194199
true
195200
}

src/parser/mod.rs

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -873,18 +873,28 @@ impl<'a> Parser<'a> {
873873
let var = self.parse_identifier()?;
874874
self.expect_keyword_is(Keyword::IN)?;
875875
let reverse = self.parse_keyword(Keyword::REVERSE);
876-
let start = self.parse_expr()?;
877-
self.expect_keyword_is(Keyword::TO)?;
878-
let end = self.parse_expr()?;
876+
let expr = self.parse_expr()?;
877+
let iteration = if !reverse
878+
&& self.dialect.supports_for_loop_over_cursor()
879+
&& !self.peek_keyword(Keyword::TO)
880+
{
881+
ForIterationSource::Cursor(expr)
882+
} else {
883+
self.expect_keyword_is(Keyword::TO)?;
884+
let end = self.parse_expr()?;
885+
ForIterationSource::Range {
886+
reverse,
887+
start: expr,
888+
end,
889+
}
890+
};
879891
self.expect_keyword_is(Keyword::DO)?;
880892
let body = self.parse_scripting_conditional_statements(&[Keyword::END])?;
881893
self.expect_keyword_is(Keyword::END)?;
882894
self.expect_keyword_is(Keyword::FOR)?;
883895
Ok(ForStatement {
884896
var,
885-
reverse,
886-
start,
887-
end,
897+
iteration,
888898
body,
889899
})
890900
}

tests/sqlparser_snowflake.rs

Lines changed: 36 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6751,9 +6751,10 @@ fn test_while_loop_end_loop() {
67516751
fn test_for_to_end_for() {
67526752
let sql = "FOR counter IN 1 TO 10 DO RETURN 1; END FOR";
67536753
match snowflake().verified_stmt(sql) {
6754-
Statement::For(ForStatement { reverse, .. }) => {
6755-
assert!(!reverse);
6756-
}
6754+
Statement::For(ForStatement { iteration, .. }) => match iteration {
6755+
ForIterationSource::Range { reverse, .. } => assert!(!reverse),
6756+
other => panic!("expected Range, got {other:?}"),
6757+
},
67576758
_ => unreachable!(),
67586759
}
67596760
}
@@ -6762,13 +6763,43 @@ fn test_for_to_end_for() {
67626763
fn test_for_reverse_to_end_for() {
67636764
let sql = "FOR counter IN REVERSE 1 TO 10 DO RETURN 1; END FOR";
67646765
match snowflake().verified_stmt(sql) {
6765-
Statement::For(ForStatement { reverse, .. }) => {
6766-
assert!(reverse);
6766+
Statement::For(ForStatement { iteration, .. }) => match iteration {
6767+
ForIterationSource::Range { reverse, .. } => assert!(reverse),
6768+
other => panic!("expected Range, got {other:?}"),
6769+
},
6770+
_ => unreachable!(),
6771+
}
6772+
}
6773+
6774+
#[test]
6775+
fn test_for_in_cursor_end_for() {
6776+
let sql = "FOR rec IN cur DO RETURN rec.price; END FOR";
6777+
match snowflake().verified_stmt(sql) {
6778+
Statement::For(ForStatement {
6779+
var, iteration, ..
6780+
}) => {
6781+
assert_eq!(var.value, "rec");
6782+
match iteration {
6783+
ForIterationSource::Cursor(source) => {
6784+
assert_eq!(source, Expr::Identifier(Ident::new("cur")));
6785+
}
6786+
other => panic!("expected Cursor, got {other:?}"),
6787+
}
67676788
}
67686789
_ => unreachable!(),
67696790
}
67706791
}
67716792

6793+
#[test]
6794+
fn test_for_in_cursor_not_supported_in_other_dialects() {
6795+
let sql = "FOR rec IN cur DO RETURN 1; END FOR";
6796+
let res = TestedDialects::new(vec![Box::new(GenericDialect {})]).parse_sql_statements(sql);
6797+
assert_eq!(
6798+
res.unwrap_err(),
6799+
ParserError::ParserError("Expected: TO, found: DO".to_string())
6800+
);
6801+
}
6802+
67726803
#[test]
67736804
fn test_loop_end_loop() {
67746805
let sql = "LOOP RETURN 1; END LOOP";

0 commit comments

Comments
 (0)