Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 55 additions & 4 deletions convergence-arrow/src/table.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,11 @@ use convergence::protocol::{DataTypeOid, ErrorResponse, FieldDescription, SqlSta
use convergence::protocol_ext::DataRowBatch;
use datafusion::arrow::array::timezone::Tz;
use datafusion::arrow::array::{
BooleanArray, Date32Array, Date64Array, Decimal128Array, Float16Array, Float32Array, Float64Array, Int16Array,
Int32Array, Int64Array, Int8Array, StringArray, StringViewArray, TimestampMicrosecondArray,
TimestampMillisecondArray, TimestampNanosecondArray, TimestampSecondArray, UInt16Array, UInt32Array, UInt64Array,
UInt8Array,
BooleanArray, Date32Array, Date64Array, Decimal128Array, DurationMicrosecondArray, DurationMillisecondArray,
DurationNanosecondArray, DurationSecondArray, Float16Array, Float32Array, Float64Array, Int16Array, Int32Array,
Int64Array, Int8Array, StringArray, StringViewArray, Time32MillisecondArray, Time32SecondArray,
Time64MicrosecondArray, Time64NanosecondArray, TimestampMicrosecondArray, TimestampMillisecondArray,
TimestampNanosecondArray, TimestampSecondArray, UInt16Array, UInt32Array, UInt64Array, UInt8Array,
};
use datafusion::arrow::datatypes::{DataType, Schema, TimeUnit};
use datafusion::arrow::record_batch::RecordBatch;
Expand Down Expand Up @@ -64,6 +65,54 @@ pub fn record_batch_to_rows(arrow_batch: &RecordBatch, pg_batch: &mut DataRowBat
ErrorResponse::error(SqlState::InvalidDatetimeFormat, "unsupported date type")
})?)
}
DataType::Time32(unit) => match unit {
TimeUnit::Second => {
row.write_time(array_val!(Time32SecondArray, col, row_idx, value_as_time).ok_or_else(
|| ErrorResponse::error(SqlState::InvalidDatetimeFormat, "unsupported time type"),
)?)
}
TimeUnit::Millisecond => row.write_time(
array_val!(Time32MillisecondArray, col, row_idx, value_as_time).ok_or_else(|| {
ErrorResponse::error(SqlState::InvalidDatetimeFormat, "unsupported time type")
})?,
),
_ => {}
},
DataType::Time64(unit) => match unit {
TimeUnit::Microsecond => row.write_time(
array_val!(Time64MicrosecondArray, col, row_idx, value_as_time).ok_or_else(|| {
ErrorResponse::error(SqlState::InvalidDatetimeFormat, "unsupported time type")
})?,
),
TimeUnit::Nanosecond => row.write_time(
array_val!(Time64NanosecondArray, col, row_idx, value_as_time).ok_or_else(|| {
ErrorResponse::error(SqlState::InvalidDatetimeFormat, "unsupported time type")
})?,
),
_ => {}
},
DataType::Duration(unit) => match unit {
TimeUnit::Second => row.write_duration(
array_val!(DurationSecondArray, col, row_idx, value_as_duration).ok_or_else(|| {
ErrorResponse::error(SqlState::InvalidDatetimeFormat, "unsupported time type")
})?,
),
TimeUnit::Millisecond => row.write_duration(
array_val!(DurationMillisecondArray, col, row_idx, value_as_duration).ok_or_else(|| {
ErrorResponse::error(SqlState::InvalidDatetimeFormat, "unsupported time type")
})?,
),
TimeUnit::Microsecond => row.write_duration(
array_val!(DurationMicrosecondArray, col, row_idx, value_as_duration).ok_or_else(|| {
ErrorResponse::error(SqlState::InvalidDatetimeFormat, "unsupported time type")
})?,
),
TimeUnit::Nanosecond => row.write_duration(
array_val!(DurationNanosecondArray, col, row_idx, value_as_duration).ok_or_else(|| {
ErrorResponse::error(SqlState::InvalidDatetimeFormat, "unsupported time type")
})?,
),
},
DataType::Timestamp(unit, tz) => {
match tz {
Some(tz) => {
Expand Down Expand Up @@ -140,6 +189,8 @@ pub fn data_type_to_oid(ty: &DataType) -> Result<DataTypeOid, ErrorResponse> {
DataType::Decimal128(_, _) => DataTypeOid::Numeric,
DataType::Utf8 | DataType::Utf8View => DataTypeOid::Text,
DataType::Date32 | DataType::Date64 => DataTypeOid::Date,
DataType::Time32(_) | DataType::Time64(_) => DataTypeOid::Time,
DataType::Duration(_) => DataTypeOid::Interval,
DataType::Timestamp(_, tz) => match tz {
Some(_) => DataTypeOid::Timestamptz,
None => DataTypeOid::Timestamp,
Expand Down
184 changes: 141 additions & 43 deletions convergence-arrow/tests/test_arrow.rs
Original file line number Diff line number Diff line change
@@ -1,19 +1,22 @@
use async_trait::async_trait;
use chrono::{DateTime, NaiveDate, NaiveDateTime};
use chrono::{DateTime, Duration, FixedOffset, NaiveDate, NaiveDateTime, NaiveTime, TimeDelta};
use convergence::engine::{Engine, Portal};
use convergence::protocol::{ErrorResponse, FieldDescription};
use convergence::protocol_ext::DataRowBatch;
use convergence::server::{self, BindOptions};
use convergence::sqlparser::ast::Statement;
use convergence_arrow::table::{record_batch_to_rows, schema_to_field_desc};
use datafusion::arrow::array::{
ArrayRef, Date32Array, Decimal128Array, Float32Array, Int32Array, StringArray, StringViewArray,
TimestampSecondArray,
ArrayRef, Date32Array, Decimal128Array, DurationMicrosecondArray, DurationMillisecondArray,
DurationNanosecondArray, DurationSecondArray, Float32Array, Int32Array, StringArray, StringViewArray,
Time32MillisecondArray, Time32SecondArray, Time64MicrosecondArray, Time64NanosecondArray, TimestampSecondArray,
};
use datafusion::arrow::datatypes::{DataType, Field, Schema, TimeUnit};
use datafusion::arrow::record_batch::RecordBatch;
use rust_decimal::Decimal;
use std::convert::TryInto;
use std::sync::Arc;
use tokio_postgres::types::{FromSql, Type};
use tokio_postgres::{connect, NoTls};

struct ArrowPortal {
Expand Down Expand Up @@ -47,6 +50,23 @@ impl ArrowEngine {
Arc::new(TimestampSecondArray::from(vec![1577854800, 1580533200, 1583038800]).with_timezone("+05:00"))
as ArrayRef;
let date_col = Arc::new(Date32Array::from(vec![0, 1, 2])) as ArrayRef;
let time_s_col = Arc::new(Time32SecondArray::from(vec![30, 60, 90])) as ArrayRef;
let time_ms_col = Arc::new(Time32MillisecondArray::from(vec![30_000, 60_000, 90_000])) as ArrayRef;
let time_mcs_col = Arc::new(Time64MicrosecondArray::from(vec![30_000_000, 60_000_000, 90_000_000])) as ArrayRef;
let time_ns_col = Arc::new(Time64NanosecondArray::from(vec![
30_000_000_000,
60_000_000_000,
90_000_000_000,
])) as ArrayRef;
let duration_s_col = Arc::new(DurationSecondArray::from(vec![3, 6, 9])) as ArrayRef;
let duration_ms_col = Arc::new(DurationMillisecondArray::from(vec![3_000, 6_000, 9_000])) as ArrayRef;
let duration_mcs_col =
Arc::new(DurationMicrosecondArray::from(vec![3_000_000, 6_000_000, 9_000_000])) as ArrayRef;
let duration_ns_col = Arc::new(DurationNanosecondArray::from(vec![
3_000_000_000,
6_000_000_000,
9_000_000_000,
])) as ArrayRef;

let schema = Schema::new(vec![
Field::new("int_col", DataType::Int32, true),
Expand All @@ -61,6 +81,14 @@ impl ArrowEngine {
true,
),
Field::new("date_col", DataType::Date32, true),
Field::new("time_s_col", DataType::Time32(TimeUnit::Second), true),
Field::new("time_ms_col", DataType::Time32(TimeUnit::Millisecond), true),
Field::new("time_mcs_col", DataType::Time64(TimeUnit::Microsecond), true),
Field::new("time_ns_col", DataType::Time64(TimeUnit::Nanosecond), true),
Field::new("duration_s_col", DataType::Duration(TimeUnit::Second), true),
Field::new("duration_ms_col", DataType::Duration(TimeUnit::Millisecond), true),
Field::new("duration_mcs_col", DataType::Duration(TimeUnit::Microsecond), true),
Field::new("duration_ns_col", DataType::Duration(TimeUnit::Nanosecond), true),
]);

Self {
Expand All @@ -75,6 +103,14 @@ impl ArrowEngine {
ts_col,
ts_tz_col,
date_col,
time_s_col,
time_ms_col,
time_mcs_col,
time_ns_col,
duration_s_col,
duration_ms_col,
duration_mcs_col,
duration_ns_col,
],
)
.expect("failed to create batch"),
Expand Down Expand Up @@ -114,14 +150,45 @@ async fn setup() -> tokio_postgres::Client {
client
}

// remove after https://github.com/sfackler/rust-postgres/pull/1238 is merged
/// Test-only newtype that lets `tokio_postgres` decode an `INTERVAL` column into a chrono `TimeDelta`.
#[derive(Clone, Copy, Default, PartialEq, Eq, PartialOrd, Ord, Debug, Hash)]
struct DurationWrapper(TimeDelta);

impl<'a> FromSql<'a> for DurationWrapper {
    /// Decodes the raw column payload as a big-endian `i64` count of microseconds.
    ///
    /// # Errors
    ///
    /// Returns the slice-to-array conversion error when the payload is not exactly
    /// 8 bytes long, so a malformed or differently encoded value fails the query
    /// instead of panicking inside the driver.
    fn from_sql(_ty: &Type, raw: &[u8]) -> Result<Self, Box<dyn std::error::Error + Sync + Send>> {
        let bytes: [u8; 8] = raw.try_into()?;
        Ok(DurationWrapper(Duration::microseconds(i64::from_be_bytes(bytes))))
    }

    /// Accepts only the `INTERVAL` type.
    fn accepts(ty: &Type) -> bool {
        matches!(ty, &Type::INTERVAL)
    }
}

#[tokio::test]
async fn basic_data_types() {
let client = setup().await;

let rows = client.query("select 1", &[]).await.unwrap();
let get_row = |idx: usize| {
let row = &rows[idx];
let cols: (i32, f32, Decimal, &str, &str, NaiveDateTime, DateTime<_>, NaiveDate) = (
let cols: (
i32,
f32,
Decimal,
&str,
&str,
NaiveDateTime,
DateTime<FixedOffset>,
NaiveDate,
NaiveTime,
NaiveTime,
NaiveTime,
NaiveTime,
DurationWrapper,
DurationWrapper,
DurationWrapper,
DurationWrapper,
) = (
row.get(0),
row.get(1),
row.get(2),
Expand All @@ -130,56 +197,87 @@ async fn basic_data_types() {
row.get(5),
row.get(6),
row.get(7),
row.get(8),
row.get(9),
row.get(10),
row.get(11),
row.get(12),
row.get(13),
row.get(14),
row.get(15),
);
cols
};

assert_eq!(
get_row(0),
(
1,
1.5,
Decimal::from(11),
"a",
"aa",
NaiveDate::from_ymd_opt(2020, 1, 1)
let row = get_row(0);
assert!(row.0 == 1);
assert!(row.1 == 1.5);
assert!(row.2 == Decimal::from(11));
assert!(row.3 == "a");
assert!(row.4 == "aa");
assert!(
row.5
== NaiveDate::from_ymd_opt(2020, 1, 1)
.unwrap()
.and_hms_opt(0, 0, 0)
.unwrap(),
DateTime::from_timestamp_millis(1577854800000).unwrap(),
NaiveDate::from_ymd_opt(1970, 1, 1).unwrap(),
)
.unwrap()
);
assert_eq!(
get_row(1),
(
2,
2.5,
Decimal::from(22),
"b",
"bb",
NaiveDate::from_ymd_opt(2020, 2, 1)
assert!(row.6 == DateTime::from_timestamp_millis(1577854800000).unwrap());
assert!(row.7 == NaiveDate::from_ymd_opt(1970, 1, 1).unwrap());
assert!(row.8 == NaiveTime::from_hms_opt(0, 0, 30).unwrap());
assert!(row.9 == NaiveTime::from_hms_opt(0, 0, 30).unwrap());
assert!(row.10 == NaiveTime::from_hms_opt(0, 0, 30).unwrap());
assert!(row.11 == NaiveTime::from_hms_opt(0, 0, 30).unwrap());
assert!(row.12 == DurationWrapper(Duration::seconds(3)));
assert!(row.13 == DurationWrapper(Duration::seconds(3)));
assert!(row.14 == DurationWrapper(Duration::seconds(3)));
assert!(row.15 == DurationWrapper(Duration::seconds(3)));

let row = get_row(1);
assert!(row.0 == 2);
assert!(row.1 == 2.5);
assert!(row.2 == Decimal::from(22));
assert!(row.3 == "b");
assert!(row.4 == "bb");
assert!(
row.5
== NaiveDate::from_ymd_opt(2020, 2, 1)
.unwrap()
.and_hms_opt(0, 0, 0)
.unwrap(),
DateTime::from_timestamp_millis(1580533200000).unwrap(),
NaiveDate::from_ymd_opt(1970, 1, 2).unwrap()
)
.unwrap()
);
assert_eq!(
get_row(2),
(
3,
3.5,
Decimal::from(33),
"c",
"cc",
NaiveDate::from_ymd_opt(2020, 3, 1)
assert!(row.6 == DateTime::from_timestamp_millis(1580533200000).unwrap());
assert!(row.7 == NaiveDate::from_ymd_opt(1970, 1, 2).unwrap());
assert!(row.8 == NaiveTime::from_hms_opt(0, 1, 0).unwrap());
assert!(row.9 == NaiveTime::from_hms_opt(0, 1, 0).unwrap());
assert!(row.10 == NaiveTime::from_hms_opt(0, 1, 0).unwrap());
assert!(row.11 == NaiveTime::from_hms_opt(0, 1, 0).unwrap());
assert!(row.12 == DurationWrapper(Duration::seconds(6)));
assert!(row.13 == DurationWrapper(Duration::seconds(6)));
assert!(row.14 == DurationWrapper(Duration::seconds(6)));
assert!(row.15 == DurationWrapper(Duration::seconds(6)));

let row = get_row(2);
assert!(row.0 == 3);
assert!(row.1 == 3.5);
assert!(row.2 == Decimal::from(33));
assert!(row.3 == "c");
assert!(row.4 == "cc");
assert!(
row.5
== NaiveDate::from_ymd_opt(2020, 3, 1)
.unwrap()
.and_hms_opt(0, 0, 0)
.unwrap(),
DateTime::from_timestamp_millis(1583038800000).unwrap(),
NaiveDate::from_ymd_opt(1970, 1, 3).unwrap()
)
.unwrap()
);
assert!(row.6 == DateTime::from_timestamp_millis(1583038800000).unwrap());
assert!(row.7 == NaiveDate::from_ymd_opt(1970, 1, 3).unwrap());
assert!(row.8 == NaiveTime::from_hms_opt(0, 1, 30).unwrap());
assert!(row.9 == NaiveTime::from_hms_opt(0, 1, 30).unwrap());
assert!(row.10 == NaiveTime::from_hms_opt(0, 1, 30).unwrap());
assert!(row.11 == NaiveTime::from_hms_opt(0, 1, 30).unwrap());
assert!(row.12 == DurationWrapper(Duration::seconds(9)));
assert!(row.13 == DurationWrapper(Duration::seconds(9)));
assert!(row.14 == DurationWrapper(Duration::seconds(9)));
assert!(row.15 == DurationWrapper(Duration::seconds(9)));
}
5 changes: 5 additions & 0 deletions convergence/src/protocol.rs
Original file line number Diff line number Diff line change
Expand Up @@ -78,9 +78,14 @@ data_types! {
Numeric = 1700, -1

Date = 1082, 4

Time = 1083, 8

Timestamp = 1114, 8
Timestamptz = 1184, 8

Interval = 1186, 16

Text = 25, -1
}

Expand Down
28 changes: 26 additions & 2 deletions convergence/src/protocol_ext.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

use crate::protocol::{ConnectionCodec, FormatCode, ProtocolError, RowDescription};
use bytes::{BufMut, BytesMut};
use chrono::{DateTime, FixedOffset, NaiveDate, NaiveDateTime};
use chrono::{DateTime, Duration, FixedOffset, NaiveDate, NaiveDateTime, NaiveTime, Timelike};
use rust_decimal::Decimal;
use tokio_postgres::types::{ToSql, Type};
use tokio_util::codec::Encoder;
Expand Down Expand Up @@ -119,7 +119,31 @@ impl<'a> DataRowWriter<'a> {
}
}

/// Writes a timestamp value for the next column.
/// Writes a time-of-day value for the next column.
///
/// In binary format the value is passed to `write_int8` as the number of
/// microseconds since midnight. In text format the `Display` rendering of the
/// `NaiveTime` is passed to `write_string`.
pub fn write_time(&mut self, val: NaiveTime) {
    match self.parent.format_code {
        FormatCode::Binary => {
            // Widen to i64 *before* scaling: `num_seconds_from_midnight()` is a u32, and
            // seconds * 1_000_000 exceeds u32::MAX for every time after 01:11:34.
            let micros = i64::from(val.num_seconds_from_midnight()) * 1_000_000
                + i64::from(val.nanosecond()) / 1_000;
            self.write_int8(micros);
        }
        FormatCode::Text => self.write_string(&val.to_string()),
    }
}

/// Writes a duration (interval) value for the next column.
///
/// In binary format the total number of whole microseconds is passed to
/// `write_int8`; durations too large to express as `i64` microseconds are clamped
/// to `i64::MIN` / `i64::MAX`. In text format the `Display` rendering of the
/// `Duration` is passed to `write_string`.
///
/// NOTE(review): the `Interval` type is declared with a size of 16 in the protocol
/// type table, and PostgreSQL's binary interval carries microseconds plus day and
/// month fields. Only the microsecond part is written here (presumably 8 bytes) --
/// confirm that real clients can decode this.
pub fn write_duration(&mut self, val: Duration) {
    match self.parent.format_code {
        FormatCode::Binary => {
            // `num_microseconds` returns `None` only when the total does not fit in an
            // i64. Recomputing seconds * 1_000_000 by hand would overflow in exactly the
            // same way, so saturate in the direction of the duration's sign instead.
            let total_micros = val.num_microseconds().unwrap_or_else(|| {
                if val.num_seconds() < 0 {
                    i64::MIN
                } else {
                    i64::MAX
                }
            });
            self.write_int8(total_micros);
        }
        FormatCode::Text => self.write_string(&val.to_string()),
    }
}

/// Writes a timestamp value for the next column.
pub fn write_timestamp(&mut self, val: NaiveDateTime) {
match self.parent.format_code {
FormatCode::Binary => {
Expand Down
Loading