Skip to content

Commit 4b4a9d7

Browse files
authored
Snowflake: Lambda functions (apache#2192)
1 parent 8e36e8e commit 4b4a9d7

File tree

5 files changed

+114
-18
lines changed

5 files changed

+114
-18
lines changed

src/ast/mod.rs

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1423,7 +1423,7 @@ impl fmt::Display for AccessExpr {
14231423
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
14241424
pub struct LambdaFunction {
14251425
/// The parameters to the lambda function.
1426-
pub params: OneOrManyWithParens<Ident>,
1426+
pub params: OneOrManyWithParens<LambdaFunctionParameter>,
14271427
/// The body of the lambda function.
14281428
pub body: Box<Expr>,
14291429
/// The syntax style used to write the lambda function.
@@ -1448,6 +1448,27 @@ impl fmt::Display for LambdaFunction {
14481448
}
14491449
}
14501450

1451+
/// A parameter to a lambda function, optionally with a data type.
1452+
#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
1453+
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
1454+
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
1455+
pub struct LambdaFunctionParameter {
1456+
/// The name of the parameter
1457+
pub name: Ident,
1458+
/// The optional data type of the parameter
1459+
/// [Snowflake Syntax](https://docs.snowflake.com/en/sql-reference/functions/filter#arguments)
1460+
pub data_type: Option<DataType>,
1461+
}
1462+
1463+
impl fmt::Display for LambdaFunctionParameter {
1464+
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
1465+
match &self.data_type {
1466+
Some(dt) => write!(f, "{} {}", self.name, dt),
1467+
None => write!(f, "{}", self.name),
1468+
}
1469+
}
1470+
}
1471+
14511472
/// The syntax style for a lambda function.
14521473
#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash, Copy)]
14531474
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]

src/dialect/snowflake.rs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -662,6 +662,11 @@ impl Dialect for SnowflakeDialect {
662662
fn supports_select_wildcard_rename(&self) -> bool {
663663
true
664664
}
665+
666+
/// See <https://docs.snowflake.com/en/user-guide/querying-semistructured#label-higher-order-functions>
667+
fn supports_lambda_functions(&self) -> bool {
668+
true
669+
}
665670
}
666671

667672
// Peeks ahead to identify tokens that are expected after

src/parser/mod.rs

Lines changed: 57 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1606,10 +1606,34 @@ impl<'a> Parser<'a> {
16061606
value: self.parse_introduced_string_expr()?.into(),
16071607
})
16081608
}
1609+
// An unreserved word (likely an identifier) is followed by an arrow,
1610+
// which indicates a lambda function with a single, untyped parameter.
1611+
// For example: `a -> a * 2`.
16091612
Token::Arrow if self.dialect.supports_lambda_functions() => {
16101613
self.expect_token(&Token::Arrow)?;
16111614
Ok(Expr::Lambda(LambdaFunction {
1612-
params: OneOrManyWithParens::One(w.to_ident(w_span)),
1615+
params: OneOrManyWithParens::One(LambdaFunctionParameter {
1616+
name: w.to_ident(w_span),
1617+
data_type: None,
1618+
}),
1619+
body: Box::new(self.parse_expr()?),
1620+
syntax: LambdaSyntax::Arrow,
1621+
}))
1622+
}
1623+
// An unreserved word (likely an identifier) that is followed by another word (likley a data type)
1624+
// which is then followed by an arrow, which indicates a lambda function with a single, typed parameter.
1625+
// For example: `a INT -> a * 2`.
1626+
Token::Word(_)
1627+
if self.dialect.supports_lambda_functions()
1628+
&& self.peek_nth_token_ref(1).token == Token::Arrow =>
1629+
{
1630+
let data_type = self.parse_data_type()?;
1631+
self.expect_token(&Token::Arrow)?;
1632+
Ok(Expr::Lambda(LambdaFunction {
1633+
params: OneOrManyWithParens::One(LambdaFunctionParameter {
1634+
name: w.to_ident(w_span),
1635+
data_type: Some(data_type),
1636+
}),
16131637
body: Box::new(self.parse_expr()?),
16141638
syntax: LambdaSyntax::Arrow,
16151639
}))
@@ -2195,7 +2219,7 @@ impl<'a> Parser<'a> {
21952219
return Ok(None);
21962220
}
21972221
self.maybe_parse(|p| {
2198-
let params = p.parse_comma_separated(|p| p.parse_identifier())?;
2222+
let params = p.parse_comma_separated(|p| p.parse_lambda_function_parameter())?;
21992223
p.expect_token(&Token::RParen)?;
22002224
p.expect_token(&Token::Arrow)?;
22012225
let expr = p.parse_expr()?;
@@ -2207,7 +2231,7 @@ impl<'a> Parser<'a> {
22072231
})
22082232
}
22092233

2210-
/// Parses a lambda expression using the `LAMBDA` keyword syntax.
2234+
/// Parses a lambda expression following the `LAMBDA` keyword syntax.
22112235
///
22122236
/// Syntax: `LAMBDA <params> : <expr>`
22132237
///
@@ -2217,30 +2241,49 @@ impl<'a> Parser<'a> {
22172241
///
22182242
/// See <https://duckdb.org/docs/stable/sql/functions/lambda>
22192243
fn parse_lambda_expr(&mut self) -> Result<Expr, ParserError> {
2244+
// Parse the parameters: either a single identifier or comma-separated identifiers
2245+
let params = self.parse_lambda_function_parameters()?;
2246+
// Expect the colon separator
2247+
self.expect_token(&Token::Colon)?;
2248+
// Parse the body expression
2249+
let body = self.parse_expr()?;
2250+
Ok(Expr::Lambda(LambdaFunction {
2251+
params,
2252+
body: Box::new(body),
2253+
syntax: LambdaSyntax::LambdaKeyword,
2254+
}))
2255+
}
2256+
2257+
/// Parses the parameters of a lambda function with optional typing.
2258+
fn parse_lambda_function_parameters(
2259+
&mut self,
2260+
) -> Result<OneOrManyWithParens<LambdaFunctionParameter>, ParserError> {
22202261
// Parse the parameters: either a single identifier or comma-separated identifiers
22212262
let params = if self.consume_token(&Token::LParen) {
22222263
// Parenthesized parameters: (x, y)
2223-
let params = self.parse_comma_separated(|p| p.parse_identifier())?;
2264+
let params = self.parse_comma_separated(|p| p.parse_lambda_function_parameter())?;
22242265
self.expect_token(&Token::RParen)?;
22252266
OneOrManyWithParens::Many(params)
22262267
} else {
22272268
// Unparenthesized parameters: x or x, y
2228-
let params = self.parse_comma_separated(|p| p.parse_identifier())?;
2269+
let params = self.parse_comma_separated(|p| p.parse_lambda_function_parameter())?;
22292270
if params.len() == 1 {
22302271
OneOrManyWithParens::One(params.into_iter().next().unwrap())
22312272
} else {
22322273
OneOrManyWithParens::Many(params)
22332274
}
22342275
};
2235-
// Expect the colon separator
2236-
self.expect_token(&Token::Colon)?;
2237-
// Parse the body expression
2238-
let body = self.parse_expr()?;
2239-
Ok(Expr::Lambda(LambdaFunction {
2240-
params,
2241-
body: Box::new(body),
2242-
syntax: LambdaSyntax::LambdaKeyword,
2243-
}))
2276+
Ok(params)
2277+
}
2278+
2279+
/// Parses a single parameter of a lambda function, with optional typing.
2280+
fn parse_lambda_function_parameter(&mut self) -> Result<LambdaFunctionParameter, ParserError> {
2281+
let name = self.parse_identifier()?;
2282+
let data_type = match self.peek_token().token {
2283+
Token::Word(_) => self.maybe_parse(|p| p.parse_data_type())?,
2284+
_ => None,
2285+
};
2286+
Ok(LambdaFunctionParameter { name, data_type })
22442287
}
22452288

22462289
/// Tries to parse the body of an [ODBC escaping sequence]

tests/sqlparser_common.rs

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15925,7 +15925,16 @@ fn test_lambdas() {
1592515925
]
1592615926
),
1592715927
Expr::Lambda(LambdaFunction {
15928-
params: OneOrManyWithParens::Many(vec![Ident::new("p1"), Ident::new("p2")]),
15928+
params: OneOrManyWithParens::Many(vec![
15929+
LambdaFunctionParameter {
15930+
name: Ident::new("p1"),
15931+
data_type: None
15932+
},
15933+
LambdaFunctionParameter {
15934+
name: Ident::new("p2"),
15935+
data_type: None
15936+
}
15937+
]),
1592915938
body: Box::new(Expr::Case {
1593015939
case_token: AttachedToken::empty(),
1593115940
end_token: AttachedToken::empty(),
@@ -15970,6 +15979,12 @@ fn test_lambdas() {
1597015979
"map_zip_with(map(1, 'a', 2, 'b'), map(1, 'x', 2, 'y'), (k, v1, v2) -> concat(v1, v2))",
1597115980
);
1597215981
dialects.verified_expr("transform(array(1, 2, 3), x -> x + 1)");
15982+
15983+
// Ensure all lambda variants are parsed correctly
15984+
dialects.verified_expr("a -> a * 2"); // Single parameter without type
15985+
dialects.verified_expr("a INT -> a * 2"); // Single parameter with type
15986+
dialects.verified_expr("(a, b) -> a * b"); // Multiple parameters without types
15987+
dialects.verified_expr("(a INT, b FLOAT) -> a * b"); // Multiple parameters with types
1597315988
}
1597415989

1597515990
#[test]

tests/sqlparser_databricks.rs

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,10 @@ fn test_databricks_exists() {
7272
]
7373
),
7474
Expr::Lambda(LambdaFunction {
75-
params: OneOrManyWithParens::One(Ident::new("x")),
75+
params: OneOrManyWithParens::One(LambdaFunctionParameter {
76+
name: Ident::new("x"),
77+
data_type: None
78+
}),
7679
body: Box::new(Expr::IsNull(Box::new(Expr::Identifier(Ident::new("x"))))),
7780
syntax: LambdaSyntax::Arrow,
7881
})
@@ -109,7 +112,16 @@ fn test_databricks_lambdas() {
109112
]
110113
),
111114
Expr::Lambda(LambdaFunction {
112-
params: OneOrManyWithParens::Many(vec![Ident::new("p1"), Ident::new("p2")]),
115+
params: OneOrManyWithParens::Many(vec![
116+
LambdaFunctionParameter {
117+
name: Ident::new("p1"),
118+
data_type: None
119+
},
120+
LambdaFunctionParameter {
121+
name: Ident::new("p2"),
122+
data_type: None
123+
}
124+
]),
113125
body: Box::new(Expr::Case {
114126
case_token: AttachedToken::empty(),
115127
end_token: AttachedToken::empty(),

0 commit comments

Comments
 (0)