Skip to content

Commit 7cbafd5

Browse files
authored
feat(cubesql): Support forwards directions for FETCH statement (#10377)
Previously, only `FETCH <count>` was supported. This adds handling for: - FETCH NEXT (returns 1 row) - FETCH FORWARD [count] (returns count rows, or 1 if omitted) - FETCH ALL / FETCH FORWARD ALL (returns all remaining rows)
1 parent f69f5ad commit 7cbafd5

8 files changed

Lines changed: 368 additions & 217 deletions

File tree

rust/cubesql/cubesql/e2e/tests/postgres.rs

Lines changed: 105 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -666,6 +666,109 @@ impl PostgresIntegrationTestSuite {
666666
Ok(())
667667
}
668668

669+
async fn test_fetch_directions(&self) -> RunResult<()> {
670+
self.test_simple_query(
671+
r#"DECLARE test_fetch_directions CURSOR WITH HOLD FOR SELECT generate_series(1, 100);"#
672+
.to_string(),
673+
|_| {},
674+
)
675+
.await?;
676+
677+
// Test FETCH FORWARD 1 - should return row "1"
678+
self.test_simple_query(
679+
r#"FETCH FORWARD 1 IN test_fetch_directions;"#.to_string(),
680+
|messages| {
681+
assert_eq!(messages.len(), 2); // 1 row + completion
682+
if let SimpleQueryMessage::Row(row) = &messages[0] {
683+
assert_eq!(row.get(0), Some("1"));
684+
} else {
685+
panic!("Expected Row for FETCH FORWARD 1");
686+
}
687+
},
688+
)
689+
.await?;
690+
691+
// Test FETCH NEXT - should return row "2"
692+
self.test_simple_query(
693+
r#"FETCH NEXT IN test_fetch_directions;"#.to_string(),
694+
|messages| {
695+
assert_eq!(messages.len(), 2); // 1 row + completion
696+
if let SimpleQueryMessage::Row(row) = &messages[0] {
697+
assert_eq!(row.get(0), Some("2"));
698+
} else {
699+
panic!("Expected Row for FETCH NEXT");
700+
}
701+
},
702+
)
703+
.await?;
704+
705+
// Test FETCH FORWARD 5 - should return rows 3-7
706+
self.test_simple_query(
707+
r#"FETCH FORWARD 5 IN test_fetch_directions;"#.to_string(),
708+
|messages| {
709+
assert_eq!(messages.len(), 6); // 5 rows + completion
710+
if let SimpleQueryMessage::Row(row) = &messages[0] {
711+
assert_eq!(row.get(0), Some("3"));
712+
} else {
713+
panic!("Expected Row for FETCH FORWARD 5, first row");
714+
}
715+
if let SimpleQueryMessage::Row(row) = &messages[4] {
716+
assert_eq!(row.get(0), Some("7"));
717+
} else {
718+
panic!("Expected Row for FETCH FORWARD 5, last row");
719+
}
720+
},
721+
)
722+
.await?;
723+
724+
// Test FETCH ALL - should return remaining rows (8-100 = 93 rows)
725+
self.test_simple_query(
726+
r#"FETCH ALL IN test_fetch_directions;"#.to_string(),
727+
|messages| {
728+
// 93 rows + 1 completion
729+
assert_eq!(messages.len(), 94);
730+
if let SimpleQueryMessage::Row(row) = &messages[0] {
731+
assert_eq!(row.get(0), Some("8"));
732+
} else {
733+
panic!("Expected Row for FETCH ALL, first row");
734+
}
735+
if let SimpleQueryMessage::Row(row) = &messages[92] {
736+
assert_eq!(row.get(0), Some("100"));
737+
} else {
738+
panic!("Expected Row for FETCH ALL, last row");
739+
}
740+
},
741+
)
742+
.await?;
743+
744+
self.test_simple_query(r#"CLOSE test_fetch_directions;"#.to_string(), |_| {})
745+
.await?;
746+
747+
Ok(())
748+
}
749+
750+
async fn test_fetch_forward_all(&self) -> RunResult<()> {
751+
self.test_simple_query(
752+
r#"DECLARE test_forward_all CURSOR WITH HOLD FOR SELECT generate_series(1, 10);"#
753+
.to_string(),
754+
|_| {},
755+
)
756+
.await?;
757+
758+
self.test_simple_query(
759+
r#"FETCH FORWARD ALL IN test_forward_all;"#.to_string(),
760+
|messages| {
761+
assert_eq!(messages.len(), 11); // 10 rows + 1 completion
762+
},
763+
)
764+
.await?;
765+
766+
self.test_simple_query(r#"CLOSE test_forward_all;"#.to_string(), |_| {})
767+
.await?;
768+
769+
Ok(())
770+
}
771+
669772
// Tableau Desktop uses it
670773
async fn test_simple_cursors_without_hold(&self) -> RunResult<()> {
671774
// without hold is default behaviour
@@ -1175,6 +1278,8 @@ impl AsyncTestSuite for PostgresIntegrationTestSuite {
11751278
self.test_stream_single().await?;
11761279
self.test_portal_pagination().await?;
11771280
self.test_simple_cursors().await?;
1281+
self.test_fetch_directions().await?;
1282+
self.test_fetch_forward_all().await?;
11781283
self.test_simple_cursors_without_hold().await?;
11791284
self.test_simple_cursors_close_specific().await?;
11801285
self.test_simple_cursors_close_all().await?;
Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
use std::sync::Arc;
2+
3+
use pg_srv::protocol;
4+
use sqlparser::ast::Value;
5+
6+
use super::error::ConnectionError;
7+
use crate::transport::SpanId;
8+
9+
pub fn parse_fetch_limit(
10+
limit: &Value,
11+
span_id: &Option<Arc<SpanId>>,
12+
) -> Result<usize, ConnectionError> {
13+
match limit {
14+
Value::Number(v, negative) => {
15+
if *negative {
16+
return Err(ConnectionError::Protocol(
17+
protocol::ErrorResponse::error(
18+
protocol::ErrorCode::ObjectNotInPrerequisiteState,
19+
"cursor can only scan forward".to_string(),
20+
)
21+
.into(),
22+
span_id.clone(),
23+
));
24+
}
25+
v.parse::<usize>().map_err(|err| {
26+
ConnectionError::Protocol(
27+
protocol::ErrorResponse::error(
28+
protocol::ErrorCode::ProtocolViolation,
29+
format!(r#"Unable to parse number "{}" for fetch limit: {}"#, v, err),
30+
)
31+
.into(),
32+
span_id.clone(),
33+
)
34+
})
35+
}
36+
other => Err(ConnectionError::Protocol(
37+
protocol::ErrorResponse::error(
38+
protocol::ErrorCode::ProtocolViolation,
39+
format!("FETCH limit must be a number, got: {}", other),
40+
)
41+
.into(),
42+
span_id.clone(),
43+
)),
44+
}
45+
}
Lines changed: 188 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,188 @@
1+
use std::{backtrace::Backtrace, sync::Arc};
2+
3+
use datafusion::{arrow::error::ArrowError, error::DataFusionError};
4+
use pg_srv::{
5+
protocol::{self, ErrorResponse},
6+
ProtocolError,
7+
};
8+
9+
use crate::{compile::CompilationError, transport::SpanId, CubeError};
10+
11+
#[derive(thiserror::Error, Debug)]
12+
pub enum ConnectionError {
13+
#[error("CubeError: {0}")]
14+
Cube(CubeError, Option<Arc<SpanId>>),
15+
#[error("DataFusionError: {0}")]
16+
DataFusion(DataFusionError, Option<Arc<SpanId>>),
17+
#[error("ArrowError: {0}")]
18+
Arrow(ArrowError, Option<Arc<SpanId>>),
19+
#[error("CompilationError: {0}")]
20+
CompilationError(CompilationError, Option<Arc<SpanId>>),
21+
#[error("ProtocolError: {0}")]
22+
Protocol(ProtocolError, Option<Arc<SpanId>>),
23+
}
24+
25+
impl ConnectionError {
26+
/// Return Backtrace from any variant of Enum
27+
pub fn backtrace(&self) -> Option<&Backtrace> {
28+
match &self {
29+
ConnectionError::Cube(e, _) => e.backtrace(),
30+
ConnectionError::CompilationError(e, _) => e.backtrace(),
31+
ConnectionError::Protocol(e, _) => e.backtrace(),
32+
ConnectionError::DataFusion(_, _) | ConnectionError::Arrow(_, _) => None,
33+
}
34+
}
35+
36+
/// Converts Error to protocol::ErrorResponse which is usefully for writing response to the client
37+
pub fn to_error_response(self) -> protocol::ErrorResponse {
38+
match self {
39+
ConnectionError::Cube(e, _) => Self::cube_to_error_response(&e),
40+
ConnectionError::DataFusion(e, _) => Self::df_to_error_response(&e),
41+
ConnectionError::Arrow(e, _) => Self::arrow_to_error_response(&e),
42+
ConnectionError::CompilationError(e, _) => {
43+
fn to_error_response(e: CompilationError) -> protocol::ErrorResponse {
44+
match e {
45+
CompilationError::Internal(_, _, _) => protocol::ErrorResponse::error(
46+
protocol::ErrorCode::InternalError,
47+
e.to_string(),
48+
),
49+
CompilationError::User(_, _) => protocol::ErrorResponse::error(
50+
protocol::ErrorCode::InvalidSqlStatement,
51+
e.to_string(),
52+
),
53+
CompilationError::Unsupported(_, _) => protocol::ErrorResponse::error(
54+
protocol::ErrorCode::FeatureNotSupported,
55+
e.to_string(),
56+
),
57+
CompilationError::Fatal(_, _) => protocol::ErrorResponse::fatal(
58+
protocol::ErrorCode::InternalError,
59+
e.to_string(),
60+
),
61+
}
62+
}
63+
64+
to_error_response(e)
65+
}
66+
ConnectionError::Protocol(e, _) => e.to_error_response(),
67+
}
68+
}
69+
70+
pub fn with_span_id(self, span_id: Option<Arc<SpanId>>) -> Self {
71+
match self {
72+
ConnectionError::Cube(e, _) => ConnectionError::Cube(e, span_id),
73+
ConnectionError::DataFusion(e, _) => ConnectionError::DataFusion(e, span_id),
74+
ConnectionError::Arrow(e, _) => ConnectionError::Arrow(e, span_id),
75+
ConnectionError::CompilationError(e, _) => {
76+
ConnectionError::CompilationError(e, span_id)
77+
}
78+
ConnectionError::Protocol(e, _) => ConnectionError::Protocol(e, span_id),
79+
}
80+
}
81+
82+
pub fn span_id(&self) -> Option<Arc<SpanId>> {
83+
match self {
84+
ConnectionError::Cube(_, span_id) => span_id.clone(),
85+
ConnectionError::DataFusion(_, span_id) => span_id.clone(),
86+
ConnectionError::Arrow(_, span_id) => span_id.clone(),
87+
ConnectionError::CompilationError(_, span_id) => span_id.clone(),
88+
ConnectionError::Protocol(_, span_id) => span_id.clone(),
89+
}
90+
}
91+
92+
fn cube_to_error_response(e: &CubeError) -> protocol::ErrorResponse {
93+
let message = e.to_string();
94+
// Remove `Error: ` prefix that can come from JS
95+
let message = if let Some(message) = message.strip_prefix("Error: ") {
96+
message.to_string()
97+
} else {
98+
message
99+
};
100+
protocol::ErrorResponse::error(protocol::ErrorCode::InternalError, message)
101+
}
102+
103+
fn df_to_error_response(e: &DataFusionError) -> protocol::ErrorResponse {
104+
match e {
105+
DataFusionError::ArrowError(arrow_err) => {
106+
return Self::arrow_to_error_response(arrow_err);
107+
}
108+
DataFusionError::External(err) => {
109+
if let Some(cube_err) = err.downcast_ref::<CubeError>() {
110+
return Self::cube_to_error_response(cube_err);
111+
}
112+
}
113+
_ => {}
114+
}
115+
protocol::ErrorResponse::error(
116+
protocol::ErrorCode::InternalError,
117+
format!("Post-processing Error: {}", e),
118+
)
119+
}
120+
121+
fn arrow_to_error_response(e: &ArrowError) -> protocol::ErrorResponse {
122+
match e {
123+
ArrowError::ExternalError(err) => {
124+
if let Some(df_err) = err.downcast_ref::<DataFusionError>() {
125+
return Self::df_to_error_response(df_err);
126+
}
127+
if let Some(cube_err) = err.downcast_ref::<CubeError>() {
128+
return Self::cube_to_error_response(cube_err);
129+
}
130+
}
131+
_ => {}
132+
}
133+
protocol::ErrorResponse::error(
134+
protocol::ErrorCode::InternalError,
135+
format!("Post-processing Error: {}", e),
136+
)
137+
}
138+
}
139+
140+
impl From<CubeError> for ConnectionError {
141+
fn from(e: CubeError) -> Self {
142+
ConnectionError::Cube(e, None)
143+
}
144+
}
145+
146+
impl From<CompilationError> for ConnectionError {
147+
fn from(e: CompilationError) -> Self {
148+
ConnectionError::CompilationError(e, None)
149+
}
150+
}
151+
152+
impl From<ProtocolError> for ConnectionError {
153+
fn from(e: ProtocolError) -> Self {
154+
ConnectionError::Protocol(e, None)
155+
}
156+
}
157+
158+
impl From<tokio::task::JoinError> for ConnectionError {
159+
fn from(e: tokio::task::JoinError) -> Self {
160+
ConnectionError::Cube(e.into(), None)
161+
}
162+
}
163+
164+
impl From<DataFusionError> for ConnectionError {
165+
fn from(e: DataFusionError) -> Self {
166+
ConnectionError::DataFusion(e, None)
167+
}
168+
}
169+
170+
impl From<ArrowError> for ConnectionError {
171+
fn from(e: ArrowError) -> Self {
172+
ConnectionError::Arrow(e, None)
173+
}
174+
}
175+
176+
/// Auto converting for all kind of io:Error to ConnectionError, sugar
177+
impl From<std::io::Error> for ConnectionError {
178+
fn from(e: std::io::Error) -> Self {
179+
ConnectionError::Protocol(e.into(), None)
180+
}
181+
}
182+
183+
/// Auto converting for all kind of io:Error to ConnectionError, sugar
184+
impl From<ErrorResponse> for ConnectionError {
185+
fn from(e: ErrorResponse) -> Self {
186+
ConnectionError::Protocol(e.into(), None)
187+
}
188+
}

rust/cubesql/cubesql/src/sql/postgres/extended.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ use pg_srv::{protocol, BindValue, PgTypeId, ProtocolError};
1414
use sqlparser::ast;
1515
use std::{fmt, pin::Pin, sync::Arc};
1616

17-
use crate::sql::shim::{ConnectionError, QueryPlanExt};
17+
use super::{shim::QueryPlanExt, ConnectionError};
1818
use datafusion::{
1919
arrow::array::Array, dataframe::DataFrame as DFDataFrame,
2020
physical_plan::SendableRecordBatchStream,
@@ -599,7 +599,7 @@ mod tests {
599599
};
600600
use pg_srv::protocol::Format;
601601

602-
use crate::sql::{extended::PortalFrom, shim::ConnectionError};
602+
use crate::sql::{error::ConnectionError, extended::PortalFrom};
603603
use datafusion::{
604604
arrow::{
605605
array::{ArrayRef, StringArray},
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,12 @@
1+
pub(crate) mod ast_helpers;
2+
pub(crate) mod error;
13
pub(crate) mod extended;
24
pub mod pg_auth_service;
35
pub(crate) mod pg_type;
46
pub(crate) mod service;
57
pub(crate) mod shim;
68
pub(crate) mod writer;
79

10+
pub use error::ConnectionError;
811
pub use pg_type::*;
912
pub use service::*;

0 commit comments

Comments
 (0)