diff --git a/convergence-arrow/src/table.rs b/convergence-arrow/src/table.rs index cd89418..0c521b4 100644 --- a/convergence-arrow/src/table.rs +++ b/convergence-arrow/src/table.rs @@ -6,10 +6,11 @@ use convergence::protocol::{DataTypeOid, ErrorResponse, FieldDescription, SqlSta use convergence::protocol_ext::DataRowBatch; use datafusion::arrow::array::timezone::Tz; use datafusion::arrow::array::{ - BooleanArray, Date32Array, Date64Array, Decimal128Array, Float16Array, Float32Array, Float64Array, Int16Array, - Int32Array, Int64Array, Int8Array, StringArray, StringViewArray, TimestampMicrosecondArray, - TimestampMillisecondArray, TimestampNanosecondArray, TimestampSecondArray, UInt16Array, UInt32Array, UInt64Array, - UInt8Array, + BooleanArray, Date32Array, Date64Array, Decimal128Array, DurationMicrosecondArray, DurationMillisecondArray, + DurationNanosecondArray, DurationSecondArray, Float16Array, Float32Array, Float64Array, Int16Array, Int32Array, + Int64Array, Int8Array, StringArray, StringViewArray, Time32MillisecondArray, Time32SecondArray, + Time64MicrosecondArray, Time64NanosecondArray, TimestampMicrosecondArray, TimestampMillisecondArray, + TimestampNanosecondArray, TimestampSecondArray, UInt16Array, UInt32Array, UInt64Array, UInt8Array, }; use datafusion::arrow::datatypes::{DataType, Schema, TimeUnit}; use datafusion::arrow::record_batch::RecordBatch; @@ -64,6 +65,54 @@ pub fn record_batch_to_rows(arrow_batch: &RecordBatch, pg_batch: &mut DataRowBat ErrorResponse::error(SqlState::InvalidDatetimeFormat, "unsupported date type") })?) } + DataType::Time32(unit) => match unit { + TimeUnit::Second => { + row.write_time(array_val!(Time32SecondArray, col, row_idx, value_as_time).ok_or_else( + || ErrorResponse::error(SqlState::InvalidDatetimeFormat, "unsupported time type"), + )?) + } + TimeUnit::Millisecond => row.write_time( + array_val!(Time32MillisecondArray, col, row_idx, value_as_time).ok_or_else(|| { + ErrorResponse::error(SqlState::InvalidDatetimeFormat, "unsupported time type") + })?, + ), + _ => {} + }, + DataType::Time64(unit) => match unit { + TimeUnit::Microsecond => row.write_time( + array_val!(Time64MicrosecondArray, col, row_idx, value_as_time).ok_or_else(|| { + ErrorResponse::error(SqlState::InvalidDatetimeFormat, "unsupported time type") + })?, + ), + TimeUnit::Nanosecond => row.write_time( + array_val!(Time64NanosecondArray, col, row_idx, value_as_time).ok_or_else(|| { + ErrorResponse::error(SqlState::InvalidDatetimeFormat, "unsupported time type") + })?, + ), + _ => {} + }, + DataType::Duration(unit) => match unit { + TimeUnit::Second => row.write_duration( + array_val!(DurationSecondArray, col, row_idx, value_as_duration).ok_or_else(|| { + ErrorResponse::error(SqlState::InvalidDatetimeFormat, "unsupported time type") + })?, + ), + TimeUnit::Millisecond => row.write_duration( + array_val!(DurationMillisecondArray, col, row_idx, value_as_duration).ok_or_else(|| { + ErrorResponse::error(SqlState::InvalidDatetimeFormat, "unsupported time type") + })?, + ), + TimeUnit::Microsecond => row.write_duration( + array_val!(DurationMicrosecondArray, col, row_idx, value_as_duration).ok_or_else(|| { + ErrorResponse::error(SqlState::InvalidDatetimeFormat, "unsupported time type") + })?, + ), + TimeUnit::Nanosecond => row.write_duration( + array_val!(DurationNanosecondArray, col, row_idx, value_as_duration).ok_or_else(|| { + ErrorResponse::error(SqlState::InvalidDatetimeFormat, "unsupported time type") + })?, + ), + }, DataType::Timestamp(unit, tz) => { match tz { Some(tz) => { @@ -140,6 +189,8 @@ pub fn data_type_to_oid(ty: &DataType) -> Result { DataType::Decimal128(_, _) => DataTypeOid::Numeric, DataType::Utf8 | DataType::Utf8View => DataTypeOid::Text, DataType::Date32 | DataType::Date64 => DataTypeOid::Date, + DataType::Time32(_) | DataType::Time64(_) => DataTypeOid::Time, + DataType::Duration(_) => DataTypeOid::Interval, DataType::Timestamp(_, tz) => match tz { Some(_) => DataTypeOid::Timestamptz, None => DataTypeOid::Timestamp, diff --git a/convergence-arrow/tests/test_arrow.rs b/convergence-arrow/tests/test_arrow.rs index 0862a27..86af344 100644 --- a/convergence-arrow/tests/test_arrow.rs +++ b/convergence-arrow/tests/test_arrow.rs @@ -1,5 +1,5 @@ use async_trait::async_trait; -use chrono::{DateTime, NaiveDate, NaiveDateTime}; +use chrono::{DateTime, Duration, FixedOffset, NaiveDate, NaiveDateTime, NaiveTime, TimeDelta}; use convergence::engine::{Engine, Portal}; use convergence::protocol::{ErrorResponse, FieldDescription}; use convergence::protocol_ext::DataRowBatch; @@ -7,13 +7,16 @@ use convergence::server::{self, BindOptions}; use convergence::sqlparser::ast::Statement; use convergence_arrow::table::{record_batch_to_rows, schema_to_field_desc}; use datafusion::arrow::array::{ - ArrayRef, Date32Array, Decimal128Array, Float32Array, Int32Array, StringArray, StringViewArray, - TimestampSecondArray, + ArrayRef, Date32Array, Decimal128Array, DurationMicrosecondArray, DurationMillisecondArray, + DurationNanosecondArray, DurationSecondArray, Float32Array, Int32Array, StringArray, StringViewArray, + Time32MillisecondArray, Time32SecondArray, Time64MicrosecondArray, Time64NanosecondArray, TimestampSecondArray, }; use datafusion::arrow::datatypes::{DataType, Field, Schema, TimeUnit}; use datafusion::arrow::record_batch::RecordBatch; use rust_decimal::Decimal; +use std::convert::TryInto; use std::sync::Arc; +use tokio_postgres::types::{FromSql, Type}; use tokio_postgres::{connect, NoTls}; struct ArrowPortal { @@ -47,6 +50,23 @@ impl ArrowEngine { Arc::new(TimestampSecondArray::from(vec![1577854800, 1580533200, 1583038800]).with_timezone("+05:00")) as ArrayRef; let date_col = Arc::new(Date32Array::from(vec![0, 1, 2])) as ArrayRef; + let time_s_col = Arc::new(Time32SecondArray::from(vec![30, 60, 90])) as ArrayRef; + let time_ms_col = Arc::new(Time32MillisecondArray::from(vec![30_000, 60_000, 90_000])) as ArrayRef; + let time_mcs_col = Arc::new(Time64MicrosecondArray::from(vec![30_000_000, 60_000_000, 90_000_000])) as ArrayRef; + let time_ns_col = Arc::new(Time64NanosecondArray::from(vec![ + 30_000_000_000, + 60_000_000_000, + 90_000_000_000, + ])) as ArrayRef; + let duration_s_col = Arc::new(DurationSecondArray::from(vec![3, 6, 9])) as ArrayRef; + let duration_ms_col = Arc::new(DurationMillisecondArray::from(vec![3_000, 6_000, 9_000])) as ArrayRef; + let duration_mcs_col = + Arc::new(DurationMicrosecondArray::from(vec![3_000_000, 6_000_000, 9_000_000])) as ArrayRef; + let duration_ns_col = Arc::new(DurationNanosecondArray::from(vec![ + 3_000_000_000, + 6_000_000_000, + 9_000_000_000, + ])) as ArrayRef; let schema = Schema::new(vec![ Field::new("int_col", DataType::Int32, true), @@ -61,6 +81,14 @@ impl ArrowEngine { true, ), Field::new("date_col", DataType::Date32, true), + Field::new("time_s_col", DataType::Time32(TimeUnit::Second), true), + Field::new("time_ms_col", DataType::Time32(TimeUnit::Millisecond), true), + Field::new("time_mcs_col", DataType::Time64(TimeUnit::Microsecond), true), + Field::new("time_ns_col", DataType::Time64(TimeUnit::Nanosecond), true), + Field::new("duration_s_col", DataType::Duration(TimeUnit::Second), true), + Field::new("duration_ms_col", DataType::Duration(TimeUnit::Millisecond), true), + Field::new("duration_mcs_col", DataType::Duration(TimeUnit::Microsecond), true), + Field::new("duration_ns_col", DataType::Duration(TimeUnit::Nanosecond), true), ]); Self { @@ -75,6 +103,14 @@ impl ArrowEngine { ts_col, ts_tz_col, date_col, + time_s_col, + time_ms_col, + time_mcs_col, + time_ns_col, + duration_s_col, + duration_ms_col, + duration_mcs_col, + duration_ns_col, ], ) .expect("failed to create batch"), @@ -114,6 +150,20 @@ async fn setup() -> tokio_postgres::Client { client } +// remove after https://github.com/sfackler/rust-postgres/pull/1238 is merged +#[derive(Clone, Copy, Default, PartialEq, Eq, PartialOrd, Ord, Debug, Hash)] +struct DurationWrapper(TimeDelta); + +impl<'a> FromSql<'a> for DurationWrapper { + fn from_sql(_ty: &Type, raw: &[u8]) -> Result> { + let micros = i64::from_be_bytes(raw.try_into().unwrap()); + Ok(DurationWrapper(Duration::microseconds(micros))) + } + fn accepts(ty: &Type) -> bool { + matches!(ty, &Type::INTERVAL) + } +} + #[tokio::test] async fn basic_data_types() { let client = setup().await; @@ -121,7 +171,24 @@ async fn basic_data_types() { let rows = client.query("select 1", &[]).await.unwrap(); let get_row = |idx: usize| { let row = &rows[idx]; - let cols: (i32, f32, Decimal, &str, &str, NaiveDateTime, DateTime<_>, NaiveDate) = ( + let cols: ( + i32, + f32, + Decimal, + &str, + &str, + NaiveDateTime, + DateTime, + NaiveDate, + NaiveTime, + NaiveTime, + NaiveTime, + NaiveTime, + DurationWrapper, + DurationWrapper, + DurationWrapper, + DurationWrapper, + ) = ( row.get(0), row.get(1), row.get(2), @@ -130,56 +197,87 @@ async fn basic_data_types() { row.get(5), row.get(6), row.get(7), + row.get(8), + row.get(9), + row.get(10), + row.get(11), + row.get(12), + row.get(13), + row.get(14), + row.get(15), ); cols }; - assert_eq!( - get_row(0), - ( - 1, - 1.5, - Decimal::from(11), - "a", - "aa", - NaiveDate::from_ymd_opt(2020, 1, 1) + let row = get_row(0); + assert!(row.0 == 1); + assert!(row.1 == 1.5); + assert!(row.2 == Decimal::from(11)); + assert!(row.3 == "a"); + assert!(row.4 == "aa"); + assert!( + row.5 + == NaiveDate::from_ymd_opt(2020, 1, 1) .unwrap() .and_hms_opt(0, 0, 0) - .unwrap(), - DateTime::from_timestamp_millis(1577854800000).unwrap(), - NaiveDate::from_ymd_opt(1970, 1, 1).unwrap(), - ) + .unwrap() ); - assert_eq!( - get_row(1), - ( - 2, - 2.5, - Decimal::from(22), - "b", - "bb", - NaiveDate::from_ymd_opt(2020, 2, 1) + assert!(row.6 == DateTime::from_timestamp_millis(1577854800000).unwrap()); + assert!(row.7 == NaiveDate::from_ymd_opt(1970, 1, 1).unwrap()); + assert!(row.8 == NaiveTime::from_hms_opt(0, 0, 30).unwrap()); + assert!(row.9 == NaiveTime::from_hms_opt(0, 0, 30).unwrap()); + assert!(row.10 == NaiveTime::from_hms_opt(0, 0, 30).unwrap()); + assert!(row.11 == NaiveTime::from_hms_opt(0, 0, 30).unwrap()); + assert!(row.12 == DurationWrapper(Duration::seconds(3))); + assert!(row.13 == DurationWrapper(Duration::seconds(3))); + assert!(row.14 == DurationWrapper(Duration::seconds(3))); + assert!(row.15 == DurationWrapper(Duration::seconds(3))); + + let row = get_row(1); + assert!(row.0 == 2); + assert!(row.1 == 2.5); + assert!(row.2 == Decimal::from(22)); + assert!(row.3 == "b"); + assert!(row.4 == "bb"); + assert!( + row.5 + == NaiveDate::from_ymd_opt(2020, 2, 1) .unwrap() .and_hms_opt(0, 0, 0) - .unwrap(), - DateTime::from_timestamp_millis(1580533200000).unwrap(), - NaiveDate::from_ymd_opt(1970, 1, 2).unwrap() - ) + .unwrap() ); - assert_eq!( - get_row(2), - ( - 3, - 3.5, - Decimal::from(33), - "c", - "cc", - NaiveDate::from_ymd_opt(2020, 3, 1) + assert!(row.6 == DateTime::from_timestamp_millis(1580533200000).unwrap()); + assert!(row.7 == NaiveDate::from_ymd_opt(1970, 1, 2).unwrap()); + assert!(row.8 == NaiveTime::from_hms_opt(0, 1, 0).unwrap()); + assert!(row.9 == NaiveTime::from_hms_opt(0, 1, 0).unwrap()); + assert!(row.10 == NaiveTime::from_hms_opt(0, 1, 0).unwrap()); + assert!(row.11 == NaiveTime::from_hms_opt(0, 1, 0).unwrap()); + assert!(row.12 == DurationWrapper(Duration::seconds(6))); + assert!(row.13 == DurationWrapper(Duration::seconds(6))); + assert!(row.14 == DurationWrapper(Duration::seconds(6))); + assert!(row.15 == DurationWrapper(Duration::seconds(6))); + + let row = get_row(2); + assert!(row.0 == 3); + assert!(row.1 == 3.5); + assert!(row.2 == Decimal::from(33)); + assert!(row.3 == "c"); + assert!(row.4 == "cc"); + assert!( + row.5 + == NaiveDate::from_ymd_opt(2020, 3, 1) .unwrap() .and_hms_opt(0, 0, 0) - .unwrap(), - DateTime::from_timestamp_millis(1583038800000).unwrap(), - NaiveDate::from_ymd_opt(1970, 1, 3).unwrap() - ) + .unwrap() ); + assert!(row.6 == DateTime::from_timestamp_millis(1583038800000).unwrap()); + assert!(row.7 == NaiveDate::from_ymd_opt(1970, 1, 3).unwrap()); + assert!(row.8 == NaiveTime::from_hms_opt(0, 1, 30).unwrap()); + assert!(row.9 == NaiveTime::from_hms_opt(0, 1, 30).unwrap()); + assert!(row.10 == NaiveTime::from_hms_opt(0, 1, 30).unwrap()); + assert!(row.11 == NaiveTime::from_hms_opt(0, 1, 30).unwrap()); + assert!(row.12 == DurationWrapper(Duration::seconds(9))); + assert!(row.13 == DurationWrapper(Duration::seconds(9))); + assert!(row.14 == DurationWrapper(Duration::seconds(9))); + assert!(row.15 == DurationWrapper(Duration::seconds(9))); } diff --git a/convergence/src/protocol.rs b/convergence/src/protocol.rs index 5b6b517..6cc1865 100644 --- a/convergence/src/protocol.rs +++ b/convergence/src/protocol.rs @@ -78,9 +78,14 @@ data_types! { Numeric = 1700, -1 Date = 1082, 4 + + Time = 1083, 8 + Timestamp = 1114, 8 Timestamptz = 1184, 8 + Interval = 1186, 16 + Text = 25, -1 } diff --git a/convergence/src/protocol_ext.rs b/convergence/src/protocol_ext.rs index c47037b..a1b6e33 100644 --- a/convergence/src/protocol_ext.rs +++ b/convergence/src/protocol_ext.rs @@ -2,7 +2,7 @@ use crate::protocol::{ConnectionCodec, FormatCode, ProtocolError, RowDescription}; use bytes::{BufMut, BytesMut}; -use chrono::{DateTime, FixedOffset, NaiveDate, NaiveDateTime}; +use chrono::{DateTime, Duration, FixedOffset, NaiveDate, NaiveDateTime, NaiveTime, Timelike}; use rust_decimal::Decimal; use tokio_postgres::types::{ToSql, Type}; use tokio_util::codec::Encoder; @@ -119,7 +119,31 @@ impl<'a> DataRowWriter<'a> { } } - /// Writes a timestamp value for the next column. + /// Writes a time value for the next column. + pub fn write_time(&mut self, val: NaiveTime) { + match self.parent.format_code { + FormatCode::Binary => { + self.write_int8((val.num_seconds_from_midnight() * 1_000_000 + val.nanosecond() / 1_000) as i64); + } + FormatCode::Text => self.write_string(&val.to_string()), + } + } + + /// Writes a time value for the next column. + pub fn write_duration(&mut self, val: Duration) { + match self.parent.format_code { + FormatCode::Binary => { + let total_micros = val.num_microseconds().unwrap_or_else(|| { + // Fallback for very large durations that may not fit in i64 microseconds + val.num_seconds() * 1_000_000 + (val.subsec_nanos() as i64) / 1_000 + }); + self.write_int8(total_micros); + } + FormatCode::Text => self.write_string(&val.to_string()), + } + } + + /// Writes a timestamp value fxor the next column. pub fn write_timestamp(&mut self, val: NaiveDateTime) { match self.parent.format_code { FormatCode::Binary => {