Skip to content

Commit b3e176d

Browse files
authored
Add SETOF support for PostgreSQL function return types (apache#2217)
1 parent 203ced4 commit b3e176d

File tree

7 files changed

+82
-27
lines changed

7 files changed

+82
-27
lines changed

src/ast/ddl.rs

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3533,6 +3533,28 @@ impl fmt::Display for CreateDomain {
35333533
}
35343534
}
35353535

3536+
/// The return type of a `CREATE FUNCTION` statement.
3537+
#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
3538+
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
3539+
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
3540+
pub enum FunctionReturnType {
3541+
/// `RETURNS <type>`
3542+
DataType(DataType),
3543+
/// `RETURNS SETOF <type>`
3544+
///
3545+
/// [PostgreSQL](https://www.postgresql.org/docs/current/sql-createfunction.html)
3546+
SetOf(DataType),
3547+
}
3548+
3549+
impl fmt::Display for FunctionReturnType {
3550+
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
3551+
match self {
3552+
FunctionReturnType::DataType(data_type) => write!(f, "{data_type}"),
3553+
FunctionReturnType::SetOf(data_type) => write!(f, "SETOF {data_type}"),
3554+
}
3555+
}
3556+
}
3557+
35363558
#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
35373559
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
35383560
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
@@ -3553,7 +3575,7 @@ pub struct CreateFunction {
35533575
/// List of arguments for the function.
35543576
pub args: Option<Vec<OperateFunctionArg>>,
35553577
/// The return type of the function.
3556-
pub return_type: Option<DataType>,
3578+
pub return_type: Option<FunctionReturnType>,
35573579
/// The expression that defines the function.
35583580
///
35593581
/// Examples:

src/ast/mod.rs

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -72,13 +72,13 @@ pub use self::ddl::{
7272
CreatePolicyCommand, CreatePolicyType, CreateTable, CreateTrigger, CreateView, Deduplicate,
7373
DeferrableInitial, DistStyle, DropBehavior, DropExtension, DropFunction, DropOperator,
7474
DropOperatorClass, DropOperatorFamily, DropOperatorSignature, DropPolicy, DropTrigger,
75-
ForValues, GeneratedAs, GeneratedExpressionMode, IdentityParameters, IdentityProperty,
76-
IdentityPropertyFormatKind, IdentityPropertyKind, IdentityPropertyOrder, IndexColumn,
77-
IndexOption, IndexType, KeyOrIndexDisplay, Msck, NullsDistinctOption, OperatorArgTypes,
78-
OperatorClassItem, OperatorFamilyDropItem, OperatorFamilyItem, OperatorOption, OperatorPurpose,
79-
Owner, Partition, PartitionBoundValue, ProcedureParam, ReferentialAction, RenameTableNameKind,
80-
ReplicaIdentity, TagsColumnOption, TriggerObjectKind, Truncate,
81-
UserDefinedTypeCompositeAttributeDef, UserDefinedTypeInternalLength,
75+
ForValues, FunctionReturnType, GeneratedAs, GeneratedExpressionMode, IdentityParameters,
76+
IdentityProperty, IdentityPropertyFormatKind, IdentityPropertyKind, IdentityPropertyOrder,
77+
IndexColumn, IndexOption, IndexType, KeyOrIndexDisplay, Msck, NullsDistinctOption,
78+
OperatorArgTypes, OperatorClassItem, OperatorFamilyDropItem, OperatorFamilyItem,
79+
OperatorOption, OperatorPurpose, Owner, Partition, PartitionBoundValue, ProcedureParam,
80+
ReferentialAction, RenameTableNameKind, ReplicaIdentity, TagsColumnOption, TriggerObjectKind,
81+
Truncate, UserDefinedTypeCompositeAttributeDef, UserDefinedTypeInternalLength,
8282
UserDefinedTypeRangeOption, UserDefinedTypeRepresentation, UserDefinedTypeSqlDefinitionOption,
8383
UserDefinedTypeStorage, ViewColumnDef,
8484
};

src/keywords.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -938,6 +938,7 @@ define_keywords!(
938938
SESSION_USER,
939939
SET,
940940
SETERROR,
941+
SETOF,
941942
SETS,
942943
SETTINGS,
943944
SHARE,

src/parser/mod.rs

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -5594,7 +5594,7 @@ impl<'a> Parser<'a> {
55945594
self.expect_token(&Token::RParen)?;
55955595

55965596
let return_type = if self.parse_keyword(Keyword::RETURNS) {
5597-
Some(self.parse_data_type()?)
5597+
Some(self.parse_function_return_type()?)
55985598
} else {
55995599
None
56005600
};
@@ -5774,7 +5774,7 @@ impl<'a> Parser<'a> {
57745774
let (name, args) = self.parse_create_function_name_and_params()?;
57755775

57765776
let return_type = if self.parse_keyword(Keyword::RETURNS) {
5777-
Some(self.parse_data_type()?)
5777+
Some(self.parse_function_return_type()?)
57785778
} else {
57795779
None
57805780
};
@@ -5877,11 +5877,11 @@ impl<'a> Parser<'a> {
58775877
})
58785878
})?;
58795879

5880-
let return_type = if return_table.is_some() {
5881-
return_table
5882-
} else {
5883-
Some(self.parse_data_type()?)
5880+
let data_type = match return_table {
5881+
Some(table_type) => table_type,
5882+
None => self.parse_data_type()?,
58845883
};
5884+
let return_type = Some(FunctionReturnType::DataType(data_type));
58855885

58865886
let _ = self.parse_keyword(Keyword::AS);
58875887

@@ -5933,6 +5933,14 @@ impl<'a> Parser<'a> {
59335933
})
59345934
}
59355935

5936+
fn parse_function_return_type(&mut self) -> Result<FunctionReturnType, ParserError> {
5937+
if self.parse_keyword(Keyword::SETOF) {
5938+
Ok(FunctionReturnType::SetOf(self.parse_data_type()?))
5939+
} else {
5940+
Ok(FunctionReturnType::DataType(self.parse_data_type()?))
5941+
}
5942+
}
5943+
59365944
fn parse_create_function_name_and_params(
59375945
&mut self,
59385946
) -> Result<(ObjectName, Vec<OperateFunctionArg>), ParserError> {
@@ -8608,7 +8616,7 @@ impl<'a> Parser<'a> {
86088616
}
86098617
}
86108618

8611-
/// Parse a single [PartitionBoundValue].
8619+
/// Parse a single partition bound value (MINVALUE, MAXVALUE, or expression).
86128620
fn parse_partition_bound_value(&mut self) -> Result<PartitionBoundValue, ParserError> {
86138621
if self.parse_keyword(Keyword::MINVALUE) {
86148622
Ok(PartitionBoundValue::MinValue)

tests/sqlparser_bigquery.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2289,7 +2289,7 @@ fn test_bigquery_create_function() {
22892289
Ident::new("myfunction"),
22902290
]),
22912291
args: Some(vec![OperateFunctionArg::with_name("x", DataType::Float64),]),
2292-
return_type: Some(DataType::Float64),
2292+
return_type: Some(FunctionReturnType::DataType(DataType::Float64)),
22932293
function_body: Some(CreateFunctionBody::AsAfterOptions(Expr::Value(
22942294
number("42").with_empty_span()
22952295
))),

tests/sqlparser_mssql.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -255,7 +255,7 @@ fn parse_create_function() {
255255
default_expr: None,
256256
},
257257
]),
258-
return_type: Some(DataType::Int(None)),
258+
return_type: Some(FunctionReturnType::DataType(DataType::Int(None))),
259259
function_body: Some(CreateFunctionBody::AsBeginEnd(BeginEndStatements {
260260
begin_token: AttachedToken::empty(),
261261
statements: vec![Statement::Return(ReturnStatement {
@@ -430,7 +430,7 @@ fn parse_create_function_parameter_default_values() {
430430
data_type: DataType::Int(None),
431431
default_expr: Some(Expr::Value((number("42")).with_empty_span())),
432432
},]),
433-
return_type: Some(DataType::Int(None)),
433+
return_type: Some(FunctionReturnType::DataType(DataType::Int(None))),
434434
function_body: Some(CreateFunctionBody::AsBeginEnd(BeginEndStatements {
435435
begin_token: AttachedToken::empty(),
436436
statements: vec![Statement::Return(ReturnStatement {

tests/sqlparser_postgres.rs

Lines changed: 33 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -4441,7 +4441,7 @@ $$"#;
44414441
DataType::Varchar(None),
44424442
),
44434443
]),
4444-
return_type: Some(DataType::Boolean),
4444+
return_type: Some(FunctionReturnType::DataType(DataType::Boolean)),
44454445
language: Some("plpgsql".into()),
44464446
behavior: None,
44474447
called_on_null: None,
@@ -4484,7 +4484,7 @@ $$"#;
44844484
DataType::Int(None)
44854485
)
44864486
]),
4487-
return_type: Some(DataType::Boolean),
4487+
return_type: Some(FunctionReturnType::DataType(DataType::Boolean)),
44884488
language: Some("plpgsql".into()),
44894489
behavior: None,
44904490
called_on_null: None,
@@ -4531,7 +4531,7 @@ $$"#;
45314531
DataType::Int(None)
45324532
),
45334533
]),
4534-
return_type: Some(DataType::Boolean),
4534+
return_type: Some(FunctionReturnType::DataType(DataType::Boolean)),
45354535
language: Some("plpgsql".into()),
45364536
behavior: None,
45374537
called_on_null: None,
@@ -4578,7 +4578,7 @@ $$"#;
45784578
DataType::Int(None)
45794579
),
45804580
]),
4581-
return_type: Some(DataType::Boolean),
4581+
return_type: Some(FunctionReturnType::DataType(DataType::Boolean)),
45824582
language: Some("plpgsql".into()),
45834583
behavior: None,
45844584
called_on_null: None,
@@ -4618,7 +4618,7 @@ $$"#;
46184618
),
46194619
OperateFunctionArg::with_name("b", DataType::Varchar(None)),
46204620
]),
4621-
return_type: Some(DataType::Boolean),
4621+
return_type: Some(FunctionReturnType::DataType(DataType::Boolean)),
46224622
language: Some("plpgsql".into()),
46234623
behavior: None,
46244624
called_on_null: None,
@@ -4661,7 +4661,7 @@ fn parse_create_function() {
46614661
OperateFunctionArg::unnamed(DataType::Integer(None)),
46624662
OperateFunctionArg::unnamed(DataType::Integer(None)),
46634663
]),
4664-
return_type: Some(DataType::Integer(None)),
4664+
return_type: Some(FunctionReturnType::DataType(DataType::Integer(None))),
46654665
language: Some("SQL".into()),
46664666
behavior: Some(FunctionBehavior::Immutable),
46674667
called_on_null: Some(FunctionCalledOnNull::Strict),
@@ -4698,6 +4698,30 @@ fn parse_create_function_detailed() {
46984698
);
46994699
}
47004700

4701+
#[test]
4702+
fn parse_create_function_returns_setof() {
4703+
pg_and_generic().verified_stmt(
4704+
"CREATE FUNCTION get_users() RETURNS SETOF TEXT LANGUAGE sql AS 'SELECT name FROM users'",
4705+
);
4706+
pg_and_generic().verified_stmt(
4707+
"CREATE FUNCTION get_ids() RETURNS SETOF INTEGER LANGUAGE sql AS 'SELECT id FROM users'",
4708+
);
4709+
pg_and_generic().verified_stmt(
4710+
r#"CREATE FUNCTION get_all() RETURNS SETOF my_schema."MyType" LANGUAGE sql AS 'SELECT * FROM t'"#,
4711+
);
4712+
pg_and_generic().verified_stmt(
4713+
"CREATE FUNCTION get_rows() RETURNS SETOF RECORD LANGUAGE sql AS 'SELECT * FROM t'",
4714+
);
4715+
4716+
let sql = "CREATE FUNCTION get_names() RETURNS SETOF TEXT LANGUAGE sql AS 'SELECT name FROM t'";
4717+
match pg_and_generic().verified_stmt(sql) {
4718+
Statement::CreateFunction(CreateFunction { return_type, .. }) => {
4719+
assert_eq!(return_type, Some(FunctionReturnType::SetOf(DataType::Text)));
4720+
}
4721+
_ => panic!("Expected CreateFunction"),
4722+
}
4723+
}
4724+
47014725
#[test]
47024726
fn parse_create_function_with_security() {
47034727
let sql =
@@ -4773,10 +4797,10 @@ fn parse_create_function_c_with_module_pathname() {
47734797
"input",
47744798
DataType::Custom(ObjectName::from(vec![Ident::new("cstring")]), vec![]),
47754799
),]),
4776-
return_type: Some(DataType::Custom(
4800+
return_type: Some(FunctionReturnType::DataType(DataType::Custom(
47774801
ObjectName::from(vec![Ident::new("cas")]),
47784802
vec![]
4779-
)),
4803+
))),
47804804
language: Some("c".into()),
47814805
behavior: Some(FunctionBehavior::Immutable),
47824806
called_on_null: None,
@@ -6493,7 +6517,7 @@ fn parse_trigger_related_functions() {
64936517
if_not_exists: false,
64946518
name: ObjectName::from(vec![Ident::new("emp_stamp")]),
64956519
args: Some(vec![]),
6496-
return_type: Some(DataType::Trigger),
6520+
return_type: Some(FunctionReturnType::DataType(DataType::Trigger)),
64976521
function_body: Some(
64986522
CreateFunctionBody::AsBeforeOptions {
64996523
body: Expr::Value((

0 commit comments

Comments
 (0)