diff --git a/Cargo.lock b/Cargo.lock index 11e5c65c0bd..bf245aef791 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -10229,6 +10229,7 @@ dependencies = [ "tokio-stream", "tracing", "url", + "uuid", "vortex", "vortex-utils", "walkdir", diff --git a/vortex-datafusion/Cargo.toml b/vortex-datafusion/Cargo.toml index e50363a016d..1a8ac2f714c 100644 --- a/vortex-datafusion/Cargo.toml +++ b/vortex-datafusion/Cargo.toml @@ -36,6 +36,7 @@ object_store = { workspace = true } tokio = { workspace = true, features = ["rt-multi-thread", "fs"] } tokio-stream = { workspace = true } tracing = { workspace = true, features = ["std", "attributes"] } +uuid = { workspace = true, features = ["v7"] } vortex = { workspace = true, features = ["object_store", "tokio", "files"] } vortex-utils = { workspace = true, features = ["dashmap"] } diff --git a/vortex-datafusion/src/lib.rs b/vortex-datafusion/src/lib.rs index 15975d99157..a6b7b88ad05 100644 --- a/vortex-datafusion/src/lib.rs +++ b/vortex-datafusion/src/lib.rs @@ -48,3 +48,85 @@ where } } } + +#[cfg(test)] +mod common_tests { + use std::sync::Arc; + use std::sync::LazyLock; + + use datafusion::arrow::array::RecordBatch; + use datafusion::datasource::provider::DefaultTableFactory; + use datafusion::execution::SessionStateBuilder; + use datafusion::prelude::SessionContext; + use datafusion_common::GetExt; + use object_store::ObjectStore; + use object_store::memory::InMemory; + use url::Url; + use vortex::VortexSessionDefault; + use vortex::array::ArrayRef; + use vortex::array::arrow::FromArrowArray; + use vortex::file::WriteOptionsSessionExt; + use vortex::io::ObjectStoreWriter; + use vortex::io::VortexWrite; + use vortex::session::VortexSession; + + use crate::VortexFormatFactory; + use crate::VortexOptions; + + static VX_SESSION: LazyLock = LazyLock::new(VortexSession::default); + + pub struct TestSessionContext { + pub store: Arc, + pub session: SessionContext, + } + + impl Default for TestSessionContext { + fn default() -> Self { + 
Self::new(false) + } + } + + impl TestSessionContext { + /// Create a new test session context with the given projection pushdown setting. + pub fn new(projection_pushdown: bool) -> Self { + let store = Arc::new(InMemory::new()); + let opts = VortexOptions { + projection_pushdown, + ..Default::default() + }; + let factory = Arc::new(VortexFormatFactory::new().with_options(opts)); + let mut session_state_builder = SessionStateBuilder::new() + .with_default_features() + .with_table_factory( + factory.get_ext().to_uppercase(), + Arc::new(DefaultTableFactory::new()), + ) + .with_object_store(&Url::try_from("file://").unwrap(), store.clone()); + + if let Some(file_formats) = session_state_builder.file_formats() { + file_formats.push(factory as _); + } + + let session: SessionContext = + SessionContext::new_with_state(session_state_builder.build()).enable_url_table(); + + Self { store, session } + } + + /// Write arrow data into a vortex file. + pub async fn write_arrow_batch

(&self, path: P, batch: &RecordBatch) -> anyhow::Result<()> + where + P: Into, + { + let array = ArrayRef::from_arrow(batch, false); + let mut write = ObjectStoreWriter::new(self.store.clone(), &path.into()).await?; + VX_SESSION + .write_options() + .write(&mut write, array.to_array_stream()) + .await?; + write.shutdown().await?; + + Ok(()) + } + } +} diff --git a/vortex-datafusion/src/persistent/format.rs b/vortex-datafusion/src/persistent/format.rs index f1b11e0a45b..682a9523058 100644 --- a/vortex-datafusion/src/persistent/format.rs +++ b/vortex-datafusion/src/persistent/format.rs @@ -17,6 +17,7 @@ use datafusion_common::Result as DFResult; use datafusion_common::Statistics; use datafusion_common::config::ConfigField; use datafusion_common::config_namespace; +use datafusion_common::internal_datafusion_err; use datafusion_common::not_impl_err; use datafusion_common::parsers::CompressionTypeVariant; use datafusion_common::stats::Precision; @@ -33,6 +34,8 @@ use datafusion_datasource::source::DataSourceExec; use datafusion_expr::dml::InsertOp; use datafusion_physical_expr::LexRequirement; use datafusion_physical_plan::ExecutionPlan; +use datafusion_physical_plan::ExecutionPlanProperties; +use datafusion_physical_plan::coalesce_partitions::CoalescePartitionsExec; use futures::FutureExt; use futures::StreamExt as _; use futures::TryStreamExt as _; @@ -64,6 +67,7 @@ use crate::PrecisionExt as _; use crate::convert::TryToDataFusion; const DEFAULT_FOOTER_INITIAL_READ_SIZE_BYTES: usize = MAX_POSTSCRIPT_SIZE as usize + EOF_SIZE; +const DEFAULT_TARGET_FILE_SIZE_MB: usize = 128; /// Vortex implementation of a DataFusion [`FileFormat`]. pub struct VortexFormat { @@ -96,6 +100,24 @@ config_namespace! { /// Values smaller than `MAX_POSTSCRIPT_SIZE + EOF_SIZE` will be clamped to that minimum /// during footer parsing. pub footer_initial_read_size_bytes: usize, default = DEFAULT_FOOTER_INITIAL_READ_SIZE_BYTES + /// Target file size in megabytes for written Vortex files. 
+ /// + /// When greater than 0 for non-partitioned writes, Vortex bypasses + /// DataFusion's file demuxer and splits output files based on + /// approximate byte size rather than row count. + pub target_file_size_mb: usize, default = DEFAULT_TARGET_FILE_SIZE_MB + /// Whether to enable projection pushdown into the underlying Vortex scan. + /// + /// When enabled, projection expressions may be partially evaluated during + /// the scan. When disabled, Vortex reads only the referenced columns and + /// all expressions are evaluated after the scan. + pub projection_pushdown: bool, default = false + /// The intra-partition scan concurrency, controlling the number of row splits to process + /// concurrently per-thread within each file. + /// + /// This does not affect the overall parallelism + /// across partitions, which is controlled by DataFusion's execution configuration. + pub scan_concurrency: Option, default = None } } @@ -417,8 +439,42 @@ impl FileFormat for VortexFormat { return not_impl_err!("Overwrites are not implemented yet for Vortex"); } + let target_file_size = (self.opts.target_file_size_mb > 0) + .then(|| { + u64::try_from(self.opts.target_file_size_mb) + .map_err(|e| { + internal_datafusion_err!( + "target_file_size_mb cannot be represented as u64: {e}" + ) + }) + .map(|v| v.saturating_mul(1024 * 1024).max(1)) + }) + .transpose()?; + + // For non-partitioned writes, force a single input stream so VortexSink + // performs one coordinated write per statement instead of one + // independent write per CPU/input partition. + // + // Use coalescing rather than repartitioning to avoid introducing a + // shuffle/dispatcher step that can interleave batches from different + // input partitions. + // + // For partitioned writes, keep DataFusion's demuxer behavior. 
+ let input: Arc = if conf.table_partition_cols.is_empty() + && input.output_partitioning().partition_count() > 1 + { + Arc::new(CoalescePartitionsExec::new(input)) + } else { + input + }; + let schema = conf.output_schema().clone(); - let sink = Arc::new(VortexSink::new(conf, schema, self.session.clone())); + let sink = Arc::new(VortexSink::new( + conf, + schema, + self.session.clone(), + target_file_size, + )); Ok(Arc::new(DataSinkExec::new(input, sink, order_requirements)) as _) } diff --git a/vortex-datafusion/src/persistent/sink.rs b/vortex-datafusion/src/persistent/sink.rs index c5cdea8ee7f..95e4f2a8470 100644 --- a/vortex-datafusion/src/persistent/sink.rs +++ b/vortex-datafusion/src/persistent/sink.rs @@ -3,54 +3,135 @@ use std::any::Any; use std::sync::Arc; -use std::sync::atomic::AtomicU64; -use std::sync::atomic::Ordering; +use arrow_schema::Schema; use arrow_schema::SchemaRef; use async_trait::async_trait; -use datafusion_common::DataFusionError; use datafusion_common::Result as DFResult; -use datafusion_common_runtime::JoinSet; -use datafusion_common_runtime::SpawnedTask; -use datafusion_datasource::file_sink_config::FileSink; +use datafusion_common::arrow::array::RecordBatch; +use datafusion_common::arrow::array::RecordBatchOptions; +use datafusion_common::exec_datafusion_err; +use datafusion_datasource::ListingTableUrl; use datafusion_datasource::file_sink_config::FileSinkConfig; use datafusion_datasource::sink::DataSink; -use datafusion_datasource::write::demux::DemuxedStreamReceiver; use datafusion_datasource::write::get_writer_schema; use datafusion_execution::SendableRecordBatchStream; use datafusion_execution::TaskContext; use datafusion_physical_plan::DisplayAs; use datafusion_physical_plan::DisplayFormatType; use datafusion_physical_plan::metrics::MetricsSet; +use futures::SinkExt; use futures::StreamExt; use object_store::ObjectStore; use object_store::path::Path; -use tokio_stream::wrappers::ReceiverStream; +use tokio::task::JoinHandle; +use 
uuid::Uuid; use vortex::array::ArrayRef; use vortex::array::arrow::FromArrowArray; use vortex::array::stream::ArrayStreamAdapter; use vortex::dtype::DType; use vortex::dtype::arrow::FromArrowType; -use vortex::error::VortexResult; use vortex::file::WriteOptionsSessionExt; +use vortex::file::WriteSummary; use vortex::io::ObjectStoreWriter; use vortex::io::VortexWrite; use vortex::session::VortexSession; +use vortex_utils::aliases::hash_set::HashSet; + +struct WriteOutputOptions<'a> { + base_output_path: &'a ListingTableUrl, + target_file_size: Option, + extension: &'a str, + write_id: &'a str, + partition_column_names: &'a [String], + keep_partition_by_columns: bool, +} + +#[derive(Clone, Copy)] +struct CompressionEstimate { + prev_compressed_bytes: u64, + prev_uncompressed_bytes: u64, +} + +impl CompressionEstimate { + fn identity() -> Self { + Self { + prev_compressed_bytes: 1, + prev_uncompressed_bytes: 1, + } + } + + fn from_file_sizes(compressed_bytes: u64, uncompressed_bytes: u64) -> DFResult { + if uncompressed_bytes == 0 { + return Err(exec_datafusion_err!( + "Cannot derive compression estimate from zero uncompressed bytes" + )); + } + + Ok(Self { + prev_compressed_bytes: compressed_bytes, + prev_uncompressed_bytes: uncompressed_bytes, + }) + } + + fn estimate_compressed_size(self, uncompressed_bytes: u64) -> DFResult { + if self.prev_uncompressed_bytes == 0 { + return Err(exec_datafusion_err!( + "Compression estimate denominator must be non-zero" + )); + } + + let estimated = u128::from(uncompressed_bytes) + .checked_mul(u128::from(self.prev_compressed_bytes)) + .ok_or_else(|| { + exec_datafusion_err!( + "Compressed size estimate overflow for {} * {}", + uncompressed_bytes, + self.prev_compressed_bytes + ) + })? 
+ / u128::from(self.prev_uncompressed_bytes); + + u64::try_from(estimated).map_err(|_| { + exec_datafusion_err!("Compressed size estimate does not fit in u64: {estimated}") + }) + } +} + +struct ActiveFileWriter { + path: Path, + sender: futures::channel::mpsc::Sender, + task: JoinHandle>, +} pub struct VortexSink { config: FileSinkConfig, schema: SchemaRef, session: VortexSession, + target_file_size: Option, } impl VortexSink { - pub fn new(config: FileSinkConfig, schema: SchemaRef, session: VortexSession) -> Self { + pub fn new( + config: FileSinkConfig, + schema: SchemaRef, + session: VortexSession, + target_file_size: Option, + ) -> Self { Self { config, schema, session, + target_file_size, } } + + fn base_output_path(&self) -> DFResult<&ListingTableUrl> { + self.config + .table_paths + .first() + .ok_or_else(|| exec_datafusion_err!("Vortex sink requires at least one table path")) + } } impl std::fmt::Debug for VortexSink { @@ -91,94 +172,321 @@ impl DataSink for VortexSink { data: SendableRecordBatchStream, context: &Arc, ) -> DFResult { - FileSink::write_all(self, data, context).await + let object_store = context + .runtime_env() + .object_store(&self.config.object_store_url)?; + let writer_schema = get_writer_schema(&self.config); + let dtype = DType::from_arrow(writer_schema); + let write_id = Uuid::now_v7().simple().to_string(); + let base_output_path = self.base_output_path()?; + let partition_column_names = self + .config + .table_partition_cols + .iter() + .map(|(name, _)| name.clone()) + .collect::>(); + + let summaries = write_record_batch_stream_to_files( + self.session.clone(), + object_store, + dtype, + data, + &WriteOutputOptions { + base_output_path, + target_file_size: self.target_file_size, + extension: &self.config.file_extension, + write_id: &write_id, + partition_column_names: &partition_column_names, + keep_partition_by_columns: self.config.keep_partition_by_columns, + }, + ) + .await?; + + let mut row_count = 0_u64; + for (path, summary) 
in summaries { + row_count = row_count.checked_add(summary.row_count()).ok_or_else(|| { + exec_datafusion_err!( + "Row count overflow while aggregating sink summaries (current={}, file={})", + row_count, + summary.row_count() + ) + })?; + tracing::debug!(path = %path, "Successfully written file"); + } + + Ok(row_count) } } -#[async_trait] -impl FileSink for VortexSink { - fn config(&self) -> &FileSinkConfig { - &self.config - } +/// Write batches from a single input stream to one or more output files. +/// +/// For collection paths, files are emitted using the `{write_id}_{file_index:05}.{extension}` +/// naming scheme as produced by the underlying writer implementation. +/// For a single-file path, the original target path is used unless rolling is needed, +/// in which case additional files follow the same naming scheme. +async fn write_record_batch_stream_to_files( + session: VortexSession, + object_store: Arc, + dtype: DType, + mut data: SendableRecordBatchStream, + output_options: &WriteOutputOptions<'_>, +) -> DFResult> { + let target = output_options.target_file_size.map(|t| t.max(1)); + let single_file_output = !output_options.base_output_path.is_collection() + && output_options.base_output_path.file_extension().is_some(); - async fn spawn_writer_tasks_and_join( - &self, - _context: &Arc, - demux_task: SpawnedTask>, - mut file_stream_rx: DemuxedStreamReceiver, - object_store: Arc, - ) -> DFResult { - // This is a hack - let row_counter = Arc::new(AtomicU64::new(0)); - - let mut file_write_tasks: JoinSet> = JoinSet::new(); - - // TODO(adamg): - // 1. We can probably be better at signaling how much memory we're consuming (potentially when reading too), see ParquetSink::spawn_writer_tasks_and_join. 
- while let Some((path, rx)) = file_stream_rx.recv().await { - let session = self.session.clone(); - let row_counter = row_counter.clone(); - let object_store = object_store.clone(); - let writer_schema = get_writer_schema(&self.config); - let dtype = DType::from_arrow(writer_schema); - - // We need to spawn work because there's a dependency between the different files. If one file has too many batches buffered, - // the demux task might deadlock itself. - file_write_tasks.spawn(async move { - let stream = ReceiverStream::new(rx).map(move |rb| { - row_counter.fetch_add(rb.num_rows() as u64, Ordering::Relaxed); - VortexResult::Ok(ArrayRef::from_arrow(rb, false)) - }); - - let stream_adapter = ArrayStreamAdapter::new(dtype, stream); - - let mut sink = ObjectStoreWriter::new(object_store.clone(), &path) - .await - .map_err(|e| { - DataFusionError::Execution(format!( - "Failed to create ObjectStoreWriter: {e}" - )) - })?; + let mut results: Vec<(Path, WriteSummary)> = Vec::new(); + let mut active_writer: Option = None; + let mut uncompressed_bytes_in_file = 0_u64; + let mut file_index = 0_usize; + let mut compression_estimate = CompressionEstimate::identity(); - session - .write_options() - .write(&mut sink, stream_adapter) - .await - .map_err(|e| { - DataFusionError::Execution(format!("Failed to write Vortex file: {e}")) - })?; + let write_result: DFResult<()> = async { + while let Some(batch) = data.next().await.transpose()? { + let batch = if output_options.keep_partition_by_columns + || output_options.partition_column_names.is_empty() + { + batch + } else { + remove_partition_columns(&batch, output_options.partition_column_names)? 
+ }; - sink.shutdown().await.map_err(|e| { - DataFusionError::Execution(format!("Failed to shutdown Vortex writer: {e}")) - })?; + if active_writer.is_none() { + let file_path = output_file_path( + output_options.base_output_path, + file_index, + output_options.extension, + single_file_output, + output_options.write_id, + ); + active_writer = Some(start_file_writer( + &session, + object_store.clone(), + file_path, + dtype.clone(), + )); + } - Ok(path) - }); - } + let batch_bytes = batch_uncompressed_bytes(&batch)?; + let writer = active_writer.as_mut().ok_or_else(|| { + exec_datafusion_err!("Missing active file writer for sink output") + })?; + send_batch_to_active_writer(writer, batch).await?; + uncompressed_bytes_in_file = uncompressed_bytes_in_file + .checked_add(batch_bytes) + .ok_or_else(|| { + exec_datafusion_err!( + "Uncompressed byte counter overflow for output file {}", + writer.path + ) + })?; - while let Some(result) = file_write_tasks.join_next().await { - match result { - Ok(path) => { - let path = path?; - tracing::info!(path = %path, "Successfully written file"); - } - Err(e) => { - if e.is_panic() { - std::panic::resume_unwind(e.into_panic()); - } else { - unreachable!(); + if let Some(target) = target { + let estimated_compressed = + compression_estimate.estimate_compressed_size(uncompressed_bytes_in_file)?; + if estimated_compressed >= target { + let writer = active_writer.take().ok_or_else(|| { + exec_datafusion_err!( + "Missing active file writer while finalizing rotated output file" + ) + })?; + let file_path = writer.path.clone(); + let summary = finish_file_writer(writer).await?; + if uncompressed_bytes_in_file > 0 { + compression_estimate = CompressionEstimate::from_file_sizes( + summary.size(), + uncompressed_bytes_in_file, + )?; } + + results.push((file_path, summary)); + uncompressed_bytes_in_file = 0; + file_index += 1; } } } - demux_task - .join_unwind() + if let Some(writer) = active_writer.take() { + let file_path = 
writer.path.clone(); + let summary = finish_file_writer(writer).await?; + results.push((file_path, summary)); + } + + Ok(()) + } + .await; + + if let Err(err) = write_result { + cleanup_failed_write( + object_store, + active_writer.into_iter().collect(), + results.iter().map(|(path, _)| path.clone()).collect(), + ) + .await; + return Err(err); + } + + Ok(results) +} + +/// Generate a numbered file path from an existing path for size-based splitting. +/// +/// Given `base/file.vortex`, produces `base/file_00000.vortex`. +/// If the path has no recognized extension, appends `_00000.{extension}`. +fn numbered_path(original: &Path, index: usize, extension: &str) -> Path { + let s = original.to_string(); + let suffix = format!(".{extension}"); + if let Some(stem) = s.strip_suffix(&suffix) { + Path::from(format!("{stem}_{index:05}{suffix}")) + } else { + Path::from(format!("{s}_{index:05}.{extension}")) + } +} +fn start_file_writer( + session: &VortexSession, + object_store: Arc, + path: Path, + dtype: DType, +) -> ActiveFileWriter { + // Use a small bounded channel to enforce backpressure and avoid unbounded buffering. 
+    let (sender, receiver) = futures::channel::mpsc::channel::<RecordBatch>(1);
+    let session = session.clone();
+    let path_for_task = path.clone();
+
+    let task = tokio::spawn(async move {
+        let mut object_writer = ObjectStoreWriter::new(object_store, &path_for_task)
+            .await
+            .map_err(|e| {
+                exec_datafusion_err!(
+                    "Failed to create ObjectStoreWriter for '{}': {e}",
+                    path_for_task
+                )
+            })?;
+
+        let stream =
+            receiver.map(|rb| vortex::error::VortexResult::Ok(ArrayRef::from_arrow(rb, false)));
+        let stream_adapter = ArrayStreamAdapter::new(dtype, stream);
+
+        let summary = session
+            .write_options()
+            .write(&mut object_writer, stream_adapter)
+            .await
+            .map_err(|e| {
+                exec_datafusion_err!("Failed to write Vortex file '{}': {e}", path_for_task)
+            })?;
+
+        object_writer.shutdown().await.map_err(|e| {
+            exec_datafusion_err!("Failed to shutdown Vortex writer '{}': {e}", path_for_task)
+        })?;
+
+        Ok(summary)
+    });
+
+    ActiveFileWriter { path, sender, task }
+}
+
+async fn cleanup_failed_write(
+    object_store: Arc<dyn ObjectStore>,
+    active_writers: Vec<ActiveFileWriter>,
+    finished_paths: Vec<Path>,
+) {
+    let mut cleanup_paths = HashSet::new();
+
+    for writer in active_writers {
+        cleanup_paths.insert(writer.path.clone());
+        writer.task.abort();
+        drop(writer.task.await);
+    }
+
+    for path in finished_paths {
+        cleanup_paths.insert(path);
+    }
+
+    for path in cleanup_paths {
+        if let Err(e) = object_store.delete(&path).await {
+            tracing::warn!(path = %path, error = %e, "Failed to delete sink output during error cleanup");
+        }
+    }
+}
+
+async fn send_batch_to_active_writer(
+    writer: &mut ActiveFileWriter,
+    batch: RecordBatch,
+) -> DFResult<()> {
+    writer.sender.send(batch).await.map_err(|e| {
+        exec_datafusion_err!(
+            "Failed to send batch to active writer for '{}': {e}",
+            writer.path
+        )
+    })
+}
+
+async fn finish_file_writer(mut writer: ActiveFileWriter) -> DFResult<WriteSummary> {
+    writer.sender.close_channel();
+    match 
writer.task.await { + Ok(result) => result, + Err(e) => Err(exec_datafusion_err!( + "Vortex writer task for '{}' failed to join: {e}", + writer.path + )), + } +} + +fn batch_uncompressed_bytes(batch: &RecordBatch) -> DFResult { + u64::try_from(batch.get_array_memory_size()).map_err(|_| { + exec_datafusion_err!( + "RecordBatch memory size does not fit in u64: {}", + batch.get_array_memory_size() + ) + }) +} + +fn remove_partition_columns( + batch: &RecordBatch, + partition_column_names: &[String], +) -> DFResult { + let partition_names: Vec<_> = partition_column_names.iter().map(String::as_str).collect(); + let (columns, fields): (Vec<_>, Vec<_>) = batch + .columns() + .iter() + .zip(batch.schema().fields()) + .filter(|(_, field)| !partition_names.contains(&field.name().as_str())) + .map(|(array, field)| (Arc::clone(array), (**field).clone())) + .unzip(); + + let schema = Schema::new(fields); + if columns.is_empty() { + let options = RecordBatchOptions::default().with_row_count(Some(batch.num_rows())); + return Ok(RecordBatch::try_new_with_options( + Arc::new(schema), + columns, + &options, + )?); + } + + Ok(RecordBatch::try_new(Arc::new(schema), columns)?) +} + +/// Build the output path for a rolling write. 
+fn output_file_path( + base_output_path: &ListingTableUrl, + file_index: usize, + extension: &str, + single_file_output: bool, + write_id: &str, +) -> Path { + if single_file_output { + if file_index == 0 { + return base_output_path.prefix().clone(); + } + return numbered_path(base_output_path.prefix(), file_index, extension); + } + + let mut base = base_output_path.prefix().to_string(); + if !base.ends_with('/') { + base.push('/'); } + Path::from(format!("{base}{write_id}_{file_index:05}.{extension}")) } #[cfg(test)] @@ -189,6 +497,7 @@ mod tests { use arrow_schema::Field; use arrow_schema::Schema; use datafusion::arrow::array::Int8Array; + use datafusion::arrow::array::Int64Array; use datafusion::arrow::array::RecordBatch; use datafusion::datasource::DefaultTableSource; use datafusion::execution::SessionStateBuilder; @@ -198,13 +507,22 @@ mod tests { use datafusion::logical_expr::Values; use datafusion::prelude::SessionContext; use datafusion_common::ScalarValue; + use datafusion_common::exec_datafusion_err; + use datafusion_datasource::ListingTableUrl; use datafusion_datasource::file_format::format_as_file_type; + use futures::TryStreamExt; use rstest::rstest; use tempfile::TempDir; + use tokio::sync::oneshot; + use tokio::time::Duration; use walkdir::WalkDir; + use crate::common_tests::TestSessionContext; use crate::persistent::VortexFormatFactory; + use crate::persistent::VortexOptions; use crate::persistent::register_vortex_format_factory; + use crate::persistent::sink::ActiveFileWriter; + use crate::persistent::sink::finish_file_writer; #[tokio::test] async fn test_insert_into() { @@ -439,4 +757,1037 @@ mod tests { Ok(()) } + + fn split_path( + base_path: &object_store::path::Path, + file_index: usize, + extension: &str, + ) -> object_store::path::Path { + let mut base = base_path.to_string(); + if !base.ends_with('/') { + base.push('/'); + } + let filename = format!("part-{file_index:05}.{extension}"); + 
object_store::path::Path::from(format!("{base}{filename}"))
+    }
+
+    #[test]
+    fn test_split_path_basic() {
+        let path = object_store::path::Path::from("data/output");
+        assert_eq!(
+            split_path(&path, 0, "vortex").to_string(),
+            "data/output/part-00000.vortex"
+        );
+        assert_eq!(
+            split_path(&path, 12, "vortex").to_string(),
+            "data/output/part-00012.vortex"
+        );
+    }
+
+    #[test]
+    fn test_split_path_preserves_trailing_slash() {
+        let path = object_store::path::Path::from("nested/path/");
+        assert_eq!(
+            split_path(&path, 3, "vx").to_string(),
+            "nested/path/part-00003.vx"
+        );
+    }
+
+    #[test]
+    fn test_numbered_path() {
+        use super::numbered_path;
+
+        let path = object_store::path::Path::from("table/c1=alpha/abc123.vortex");
+        assert_eq!(
+            numbered_path(&path, 0, "vortex").to_string(),
+            "table/c1=alpha/abc123_00000.vortex"
+        );
+        assert_eq!(
+            numbered_path(&path, 5, "vortex").to_string(),
+            "table/c1=alpha/abc123_00005.vortex"
+        );
+    }
+
+    #[test]
+    fn test_numbered_path_no_extension() {
+        use super::numbered_path;
+
+        let path = object_store::path::Path::from("table/output");
+        assert_eq!(
+            numbered_path(&path, 0, "vortex").to_string(),
+            "table/output_00000.vortex"
+        );
+    }
+
+    #[test]
+    fn test_output_file_path_single_file_and_collection() {
+        use super::output_file_path;
+
+        let single = ListingTableUrl::parse("file:///tmp/output.vortex").unwrap();
+        assert_eq!(
+            output_file_path(&single, 0, "vortex", true, "wid").to_string(),
+            "tmp/output.vortex"
+        );
+        assert_eq!(
+            output_file_path(&single, 2, "vortex", true, "wid").to_string(),
+            "tmp/output_00002.vortex"
+        );
+
+        let collection = ListingTableUrl::parse("file:///tmp/table/").unwrap();
+        assert_eq!(
+            output_file_path(&collection, 3, "vortex", false, "wid").to_string(),
+            "tmp/table/wid_00003.vortex"
+        );
+    }
+
+    #[tokio::test]
+    async fn test_finish_file_writer_waits_for_task_completion() -> anyhow::Result<()> {
+        let (sender, receiver) = futures::channel::mpsc::channel::<RecordBatch>(1);
+        drop(receiver);
+
+        let 
(gate_tx, gate_rx) = oneshot::channel::<()>(); + + let writer = ActiveFileWriter { + path: object_store::path::Path::from("table/pending.vortex"), + sender, + task: tokio::spawn(async move { + let _ = gate_rx.await; + Err(exec_datafusion_err!( + "synthetic writer failure after completion gate" + )) + }), + }; + + let mut finish_fut = Box::pin(finish_file_writer(writer)); + assert!( + tokio::time::timeout(Duration::from_millis(50), &mut finish_fut) + .await + .is_err(), + "finish_file_writer returned before writer task completed" + ); + + gate_tx + .send(()) + .map_err(|_| anyhow::anyhow!("failed to release writer completion gate in test"))?; + + let err = match finish_fut.await { + Ok(_) => { + return Err(anyhow::anyhow!( + "finish_file_writer unexpectedly succeeded after gate release" + )); + } + Err(err) => err, + }; + assert!( + err.to_string().contains("synthetic writer failure"), + "unexpected error: {err}" + ); + + Ok(()) + } + + #[test] + fn test_remove_partition_columns() -> anyhow::Result<()> { + use datafusion::arrow::array::StringArray; + + use super::remove_partition_columns; + + let batch = RecordBatch::try_new( + Arc::new(Schema::new(vec![ + Field::new("part", DataType::Utf8, false), + Field::new("val", DataType::Int64, false), + ])), + vec![ + Arc::new(StringArray::from(vec!["x", "y"])), + Arc::new(Int64Array::from(vec![1, 2])), + ], + )?; + + let out = remove_partition_columns(&batch, &["part".to_string()])?; + assert_eq!(out.num_columns(), 1); + assert_eq!(out.schema().field(0).name(), "val"); + assert_eq!(out.num_rows(), 2); + + Ok(()) + } + + #[test] + fn test_remove_partition_columns_all_columns_partitioned() -> anyhow::Result<()> { + use datafusion::arrow::array::StringArray; + + use super::remove_partition_columns; + + let batch = RecordBatch::try_new( + Arc::new(Schema::new(vec![Field::new("part", DataType::Utf8, false)])), + vec![Arc::new(StringArray::from(vec!["x", "y", "z"]))], + )?; + + let out = remove_partition_columns(&batch, 
&["part".to_string()])?; + assert_eq!(out.num_columns(), 0); + assert_eq!(out.num_rows(), 3); + + Ok(()) + } + + /// Generate `count` pseudo-random i64 values using a simple LCG. + /// These values resist compression (unlike sequential or modular data), + /// giving more realistic compressed file sizes. + fn pseudo_random_i64s(count: usize, seed: i64) -> Vec { + let mut v = seed; + (0..count) + .map(|_| { + v = v + .wrapping_mul(6_364_136_223_846_793_005) + .wrapping_add(1_442_695_040_888_963_407); + v + }) + .collect() + } + + /// Tests file splitting through the full DataFusion pipeline. + /// + /// Writes ~62MB of pseudo-random Int64 data (near 1:1 compression ratio) + /// via COPY TO with a 16MB target file size. Verifies that exactly 4 files + /// are produced and each file's compressed size is approximately 16MB. + /// + /// This exercises the complete COPY TO write path through DataFusion and + /// VortexSink, unlike a direct `write_stream_to_files` call. + #[tokio::test] + async fn test_file_splitting_62mb_into_4_files() -> anyhow::Result<()> { + use datafusion::datasource::MemTable; + use datafusion_datasource::file_format::format_as_file_type; + + let ctx = TestSessionContext::default(); + + let target_mb = 16_usize; + let opts = VortexOptions { + target_file_size_mb: target_mb, + ..Default::default() + }; + let factory = VortexFormatFactory::new().with_options(opts); + + let batch_rows = 8192_usize; + let total_elements = 62 * 1024 * 1024 / 8; // ~8,126,464 i64 values ≈ 62MB Arrow memory + let num_batches = total_elements / batch_rows; + let expected_total_rows = (num_batches * batch_rows) as i64; + + let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int64, false)])); + + let mut batches = Vec::new(); + for i in 0..num_batches { + let values = pseudo_random_i64s(batch_rows, (i * batch_rows) as i64); + batches.push(RecordBatch::try_new( + schema.clone(), + vec![Arc::new(Int64Array::from(values))], + )?); + } + + let table = 
MemTable::try_new(schema.clone(), vec![batches])?;
        ctx.session.register_table("source", Arc::new(table))?;

        let source = ctx.session.table("source").await?;
        let logical_plan = LogicalPlanBuilder::copy_to(
            source.logical_plan().clone(),
            "/table/".to_string(),
            format_as_file_type(Arc::new(factory)),
            Default::default(),
            vec![],
        )?
        .build()?;

        ctx.session
            .execute_logical_plan(logical_plan)
            .await?
            .collect()
            .await?;

        let file_metas = ctx
            .store
            .list(Some(&"/table".into()))
            .try_collect::<Vec<_>>()
            .await?;

        assert_eq!(
            file_metas.len(),
            4,
            "Expected 4 files for ~62MB data with {target_mb}MB target, got {} (sizes: {:?})",
            file_metas.len(),
            file_metas.iter().map(|m| m.size).collect::<Vec<_>>()
        );

        let target_bytes = u64::try_from(target_mb * 1024 * 1024)?;
        for meta in &file_metas {
            assert!(
                meta.size > target_bytes / 2,
                "File {} is {}B, expected at least {}B (target/2)",
                meta.location,
                meta.size,
                target_bytes / 2
            );
        }

        // Verify total row count.
        let result = ctx
            .session
            .sql("SELECT COUNT(*) as cnt FROM '/table/'")
            .await?
            .collect()
            .await?;

        let count = result[0]
            .column(0)
            .as_any()
            .downcast_ref::<Int64Array>()
            .unwrap()
            .value(0);

        assert_eq!(count, expected_total_rows, "Total row count mismatch");

        Ok(())
    }

    /// Tests file splitting with compressible data through the full pipeline.
    ///
    /// Uses low-entropy Int64 values (repeating 0..255) which compress ~8:1 in
    /// Vortex. With the current code that compares Arrow memory size against
    /// `target_file_size`, files are split far too early, producing many tiny
    /// compressed files instead of files that are close to the target.
    ///
    /// For ~32MB of Arrow data (~4MB compressed at 8:1) with a 1MB target:
    /// - **Correct**: 4 files of ~1MB compressed each
    /// - **Bug**: 32 files of ~0.125MB compressed each
    #[tokio::test]
    async fn test_file_splitting_compressible_data() -> anyhow::Result<()> {
        use datafusion::datasource::MemTable;
        use datafusion_datasource::file_format::format_as_file_type;

        let ctx = TestSessionContext::default();

        let target_mb = 1_usize;
        let opts = VortexOptions {
            target_file_size_mb: target_mb,
            ..Default::default()
        };
        let factory = VortexFormatFactory::new().with_options(opts);

        // Generate low-entropy Int64 values: repeating 0..255.
        // Arrow memory: 4M × 8 bytes = 32MB.
        // Vortex compressed: each value only needs ~1 byte → ~4MB total.
        let total_elements = 4_000_000_usize;
        let batch_rows = 8192_usize;
        let num_batches = total_elements / batch_rows;

        let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int64, false)]));

        let mut batches = Vec::new();
        for i in 0..num_batches {
            let values: Vec<i64> = (0..batch_rows)
                .map(|j| ((i * batch_rows + j) % 256) as i64)
                .collect();
            batches.push(RecordBatch::try_new(
                schema.clone(),
                vec![Arc::new(Int64Array::from(values))],
            )?);
        }

        let table = MemTable::try_new(schema.clone(), vec![batches])?;
        ctx.session.register_table("source", Arc::new(table))?;

        // COPY the in-memory table into the vortex-backed path so the sink's
        // file-rolling logic runs end-to-end.
        let source = ctx.session.table("source").await?;
        let logical_plan = LogicalPlanBuilder::copy_to(
            source.logical_plan().clone(),
            "/table/".to_string(),
            format_as_file_type(Arc::new(factory)),
            Default::default(),
            vec![],
        )?
        .build()?;

        ctx.session
            .execute_logical_plan(logical_plan)
            .await?
            .collect()
            .await?;

        let file_metas = ctx
            .store
            .list(Some(&"/table".into()))
            .try_collect::<Vec<_>>()
            .await?;

        // With compressible data, there should be few files (not > 10).
        // The buggy code produces many tiny files because it splits on Arrow
        // memory (32MB / 1MB = 32 files) instead of compressed size (~4MB / 1MB = 4 files).
        let total_compressed: u64 = file_metas.iter().map(|m| m.size).sum();
        let target_bytes = u64::try_from(target_mb * 1024 * 1024)?;

        // We should have at most ~(total_compressed / target) + 1 files, not
        // ~(arrow_memory / target) files.
        let max_expected = usize::try_from(total_compressed / target_bytes + 2)?;
        assert!(
            file_metas.len() <= max_expected,
            "Too many files: got {} but total compressed is {}B with {}B target \
             (expected at most {max_expected}). Files are being split on Arrow memory \
             instead of compressed size. Sizes: {:?}",
            file_metas.len(),
            total_compressed,
            target_bytes,
            file_metas.iter().map(|m| m.size).collect::<Vec<_>>()
        );

        // Every file except the first should be reasonably sized. The first
        // file may be smaller because the compression ratio is unknown until
        // the first write completes.
        for meta in file_metas.iter().skip(1) {
            assert!(
                meta.size > target_bytes / 4,
                "File {} is {}B, far below target {}B — splitting on Arrow memory, not compressed size",
                meta.location,
                meta.size,
                target_bytes
            );
        }

        Ok(())
    }

    /// With `target_file_size_mb: 0` (rolling disabled), a multi-partition
    /// COPY must still coalesce into exactly one output file under one
    /// write_id.
    #[tokio::test]
    async fn test_write_large_batch_target_file_size_disabled() -> anyhow::Result<()> {
        use datafusion::datasource::MemTable;
        use datafusion_datasource::file_format::format_as_file_type;

        let ctx = TestSessionContext::default();

        let opts = VortexOptions {
            // Disable sink-side rolling/splitting.
            target_file_size_mb: 0,
            ..Default::default()
        };
        let factory = VortexFormatFactory::new().with_options(opts);

        let rows_per_partition = 300_000_usize;
        let num_partitions = 8_usize;
        let expected_total_rows = (rows_per_partition * num_partitions) as i64;

        let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int64, false)]));

        let mut partitions: Vec<Vec<RecordBatch>> = Vec::new();
        for p in 0..num_partitions {
            let values = pseudo_random_i64s(rows_per_partition, (p * rows_per_partition) as i64);
            partitions.push(vec![RecordBatch::try_new(
                schema.clone(),
                vec![Arc::new(Int64Array::from(values))],
            )?]);
        }

        let table = MemTable::try_new(schema, partitions)?;
        ctx.session.register_table("source", Arc::new(table))?;

        let source = ctx.session.table("source").await?;
        let logical_plan = LogicalPlanBuilder::copy_to(
            source.logical_plan().clone(),
            "/table/".to_string(),
            format_as_file_type(Arc::new(factory)),
            Default::default(),
            vec![],
        )?
        .build()?;

        ctx.session
            .execute_logical_plan(logical_plan)
            .await?
            .collect()
            .await?;

        let file_metas = ctx
            .store
            .list(Some(&"/table".into()))
            .try_collect::<Vec<_>>()
            .await?;

        // File names are `<write_id>_<n>`; the prefix identifies the sink
        // stream that produced the file.
        let unique_write_ids: vortex_utils::aliases::hash_set::HashSet<_> = file_metas
            .iter()
            .filter_map(|m| {
                m.location
                    .filename()
                    .and_then(|name| name.split_once('_'))
                    .map(|(prefix, _)| prefix.to_string())
            })
            .collect();

        assert_eq!(
            unique_write_ids.len(),
            1,
            "Expected one write_id with target size disabled; got {:?} from files: {:?}",
            unique_write_ids,
            file_metas
                .iter()
                .map(|m| format!("{}: {}B", m.location, m.size))
                .collect::<Vec<_>>()
        );

        assert_eq!(
            file_metas.len(),
            1,
            "Expected exactly one output file when target size is disabled, got {}",
            file_metas.len()
        );

        let result = ctx
            .session
            .sql("SELECT COUNT(*) as cnt FROM '/table/'")
            .await?
            .collect()
            .await?;

        let count = result[0]
            .column(0)
            .as_any()
            .downcast_ref::<Int64Array>()
            .unwrap()
            .value(0);

        assert_eq!(count, expected_total_rows, "Total row count mismatch");

        Ok(())
    }

    /// With rolling enabled but a target far above the data size, all input
    /// partitions must funnel through a single sink stream (one write_id) and
    /// coalesce into fewer files than input partitions.
    #[tokio::test]
    async fn test_target_file_size_uses_single_sink_input_partition() -> anyhow::Result<()> {
        use datafusion::datasource::MemTable;
        use datafusion_datasource::file_format::format_as_file_type;

        let ctx = TestSessionContext::default();

        let opts = VortexOptions {
            // Enable sink-side sizing, but make the threshold large enough
            // that all input data should fit in a single file.
            target_file_size_mb: 512,
            ..Default::default()
        };
        let factory = VortexFormatFactory::new().with_options(opts);

        let rows_per_partition = 300_000_usize;
        let num_partitions = 8_usize;
        let expected_total_rows = (rows_per_partition * num_partitions) as i64;
        let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int64, false)]));

        // Build a MemTable with multiple physical input partitions to mimic
        // DataFusion's parallel writer inputs.
        let mut partitions: Vec<Vec<RecordBatch>> = Vec::new();
        for p in 0..num_partitions {
            let values = pseudo_random_i64s(rows_per_partition, (p * rows_per_partition) as i64);
            partitions.push(vec![RecordBatch::try_new(
                schema.clone(),
                vec![Arc::new(Int64Array::from(values))],
            )?]);
        }

        let table = MemTable::try_new(schema, partitions)?;
        ctx.session.register_table("source", Arc::new(table))?;

        let source = ctx.session.table("source").await?;
        let logical_plan = LogicalPlanBuilder::copy_to(
            source.logical_plan().clone(),
            "/table/".to_string(),
            format_as_file_type(Arc::new(factory)),
            Default::default(),
            vec![],
        )?
        .build()?;

        ctx.session
            .execute_logical_plan(logical_plan)
            .await?
            .collect()
            .await?;

        let file_metas = ctx
            .store
            .list(Some(&"/table".into()))
            .try_collect::<Vec<_>>()
            .await?;

        let unique_write_ids: vortex_utils::aliases::hash_set::HashSet<_> = file_metas
            .iter()
            .filter_map(|m| {
                m.location
                    .filename()
                    .and_then(|name| name.split_once('_'))
                    .map(|(prefix, _)| prefix.to_string())
            })
            .collect();

        assert_eq!(
            unique_write_ids.len(),
            1,
            "Expected one write_id (single sink stream), got {:?} from files: {:?}",
            unique_write_ids,
            file_metas
                .iter()
                .map(|m| format!("{}: {}B", m.location, m.size))
                .collect::<Vec<_>>()
        );

        assert!(
            file_metas.len() < num_partitions,
            "Expected fewer output files than input partitions after coalescing; got {} files for {num_partitions} input partitions",
            file_metas.len()
        );

        let result = ctx
            .session
            .sql("SELECT COUNT(*) AS cnt FROM '/table/'")
            .await?
            .collect()
            .await?;

        let count = result[0]
            .column(0)
            .as_any()
            .downcast_ref::<Int64Array>()
            .unwrap()
            .value(0);
        assert_eq!(count, expected_total_rows, "Total row count mismatch");

        Ok(())
    }

    /// `INSERT INTO` from a multi-partition MemTable source must produce a
    /// single write_id and coalesced output files.
    #[tokio::test]
    async fn test_insert_sql_target_size_multi_partition_source_single_write_id()
    -> anyhow::Result<()> {
        use datafusion::datasource::MemTable;

        let ctx = TestSessionContext::default();

        ctx.session
            .sql(
                "CREATE EXTERNAL TABLE my_tbl \
                 (a BIGINT NOT NULL) \
                 STORED AS vortex \
                 LOCATION 'table/' \
                 OPTIONS(target_file_size_mb '64');",
            )
            .await?;

        let rows_per_partition = 300_000_usize;
        let num_partitions = 8_usize;
        let expected_total_rows = (rows_per_partition * num_partitions) as i64;
        let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int64, false)]));

        let mut partitions: Vec<Vec<RecordBatch>> = Vec::new();
        for p in 0..num_partitions {
            let values = pseudo_random_i64s(rows_per_partition, (p * rows_per_partition) as i64);
            partitions.push(vec![RecordBatch::try_new(
                schema.clone(),
                vec![Arc::new(Int64Array::from(values))],
            )?]);
        }

        let source = MemTable::try_new(schema, partitions)?;
        ctx.session.register_table("source", Arc::new(source))?;

        ctx.session
            .sql("INSERT INTO my_tbl SELECT a FROM source")
            .await?
            .collect()
            .await?;

        let all_files = ctx.store.list(None).try_collect::<Vec<_>>().await?;

        let unique_write_ids: vortex_utils::aliases::hash_set::HashSet<_> = all_files
            .iter()
            .filter_map(|m| {
                m.location
                    .filename()
                    .and_then(|name| name.split_once('_'))
                    .map(|(prefix, _)| prefix.to_string())
            })
            .collect();

        assert_eq!(
            unique_write_ids.len(),
            1,
            "Expected one write_id, got {:?} from files: {:?}",
            unique_write_ids,
            all_files
                .iter()
                .map(|m| format!("{}: {}B", m.location, m.size))
                .collect::<Vec<_>>()
        );
        assert!(
            all_files.len() < num_partitions,
            "Expected fewer files than input partitions; got {} files for {num_partitions} input partitions",
            all_files.len()
        );

        let result = ctx
            .session
            .sql("SELECT COUNT(*) AS cnt FROM my_tbl")
            .await?
            .collect()
            .await?;
        let count = result[0]
            .column(0)
            .as_any()
            .downcast_ref::<Int64Array>()
            .unwrap()
            .value(0);
        assert_eq!(count, expected_total_rows, "Total row count mismatch");

        Ok(())
    }

    /// `INSERT INTO` from a StreamingTable (pull-based, multi-partition
    /// streaming source) must still produce a single write_id.
    #[tokio::test]
    async fn test_insert_sql_streaming_source_single_write_id() -> anyhow::Result<()> {
        use arrow_schema::SchemaRef;
        use datafusion::catalog::streaming::StreamingTable;
        use datafusion::physical_plan::stream::RecordBatchStreamAdapter;
        use datafusion::physical_plan::streaming::PartitionStream;
        use futures::stream;

        // A one-batch partition stream used to emulate a streaming source.
        #[derive(Debug)]
        struct StaticPartitionStream {
            schema: SchemaRef,
            batch: RecordBatch,
        }

        impl PartitionStream for StaticPartitionStream {
            fn schema(&self) -> &SchemaRef {
                &self.schema
            }

            fn execute(
                &self,
                _ctx: Arc<datafusion::execution::TaskContext>,
            ) -> datafusion::physical_plan::SendableRecordBatchStream {
                let schema = Arc::clone(&self.schema);
                let batch = self.batch.clone();
                Box::pin(RecordBatchStreamAdapter::new(
                    schema,
                    stream::iter(vec![Ok(batch)]),
                ))
            }
        }

        let ctx = TestSessionContext::default();

        ctx.session
            .sql(
                "CREATE EXTERNAL TABLE my_tbl \
                 (a BIGINT NOT NULL) \
                 STORED AS vortex \
                 LOCATION 'table/' \
                 OPTIONS(target_file_size_mb '64');",
            )
            .await?;

        let rows_per_partition = 300_000_usize;
        let num_partitions = 8_usize;
        let expected_total_rows = (rows_per_partition * num_partitions) as i64;
        let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int64, false)]));

        let mut partitions: Vec<Arc<dyn PartitionStream>> = Vec::new();
        for p in 0..num_partitions {
            let values = pseudo_random_i64s(rows_per_partition, (p * rows_per_partition) as i64);
            let batch =
                RecordBatch::try_new(schema.clone(), vec![Arc::new(Int64Array::from(values))])?;

            partitions.push(Arc::new(StaticPartitionStream {
                schema: schema.clone(),
                batch,
            }));
        }

        let source = StreamingTable::try_new(schema, partitions)?;
        ctx.session
            .register_table("source_stream", Arc::new(source))?;

        ctx.session
            .sql("INSERT INTO my_tbl SELECT a FROM source_stream")
            .await?
            .collect()
            .await?;

        let all_files = ctx.store.list(None).try_collect::<Vec<_>>().await?;

        let unique_write_ids: vortex_utils::aliases::hash_set::HashSet<_> = all_files
            .iter()
            .filter_map(|m| {
                m.location
                    .filename()
                    .and_then(|name| name.split_once('_'))
                    .map(|(prefix, _)| prefix.to_string())
            })
            .collect();

        assert_eq!(
            unique_write_ids.len(),
            1,
            "Expected one write_id for streaming source insert, got {:?} from files: {:?}",
            unique_write_ids,
            all_files
                .iter()
                .map(|m| format!("{}: {}B", m.location, m.size))
                .collect::<Vec<_>>()
        );

        let result = ctx
            .session
            .sql("SELECT COUNT(*) AS cnt FROM my_tbl")
            .await?
            .collect()
            .await?;
        let count = result[0]
            .column(0)
            .as_any()
            .downcast_ref::<Int64Array>()
            .unwrap()
            .value(0);
        assert_eq!(count, expected_total_rows, "Total row count mismatch");

        Ok(())
    }

    /// Bypasses SQL and calls `TableProvider::insert_into` directly with a
    /// bounded `StreamingTableExec` input; still expects a single write_id.
    #[tokio::test]
    async fn test_listing_table_direct_insert_into_streaming_exec_single_write_id()
    -> anyhow::Result<()> {
        use arrow_schema::SchemaRef;
        use datafusion::physical_plan::ExecutionPlan;
        use datafusion::physical_plan::collect;
        use datafusion::physical_plan::stream::RecordBatchStreamAdapter;
        use datafusion::physical_plan::streaming::PartitionStream;
        use datafusion::physical_plan::streaming::StreamingTableExec;
        use datafusion_expr::dml::InsertOp;
        use futures::stream;

        // A one-batch partition stream used to emulate a streaming source.
        #[derive(Debug)]
        struct StaticPartitionStream {
            schema: SchemaRef,
            batch: RecordBatch,
        }

        impl PartitionStream for StaticPartitionStream {
            fn schema(&self) -> &SchemaRef {
                &self.schema
            }

            fn execute(
                &self,
                _ctx: Arc<datafusion::execution::TaskContext>,
            ) -> datafusion::physical_plan::SendableRecordBatchStream {
                let schema = Arc::clone(&self.schema);
                let batch = self.batch.clone();
                Box::pin(RecordBatchStreamAdapter::new(
                    schema,
                    stream::iter(vec![Ok(batch)]),
                ))
            }
        }

        let ctx = TestSessionContext::default();

        ctx.session
            .sql(
                "CREATE EXTERNAL TABLE my_tbl \
                 (a BIGINT NOT NULL) \
                 STORED AS vortex \
                 LOCATION 'table/' \
                 OPTIONS(target_file_size_mb '64');",
            )
            .await?;

        let table_provider = ctx.session.table_provider("my_tbl").await?;

        let rows_per_partition = 300_000_usize;
        let num_partitions = 8_usize;
        let expected_total_rows = (rows_per_partition * num_partitions) as i64;
        let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int64, false)]));

        let mut partitions: Vec<Arc<dyn PartitionStream>> = Vec::new();
        for p in 0..num_partitions {
            let values = pseudo_random_i64s(rows_per_partition, (p * rows_per_partition) as i64);
            let batch =
                RecordBatch::try_new(schema.clone(), vec![Arc::new(Int64Array::from(values))])?;

            partitions.push(Arc::new(StaticPartitionStream {
                schema: schema.clone(),
                batch,
            }));
        }

        // `false` => bounded (finite) streaming input.
        let input = Arc::new(StreamingTableExec::try_new(
            schema,
            partitions,
            None,
            Vec::new(),
            false,
            None,
        )?) as Arc<dyn ExecutionPlan>;

        let plan = table_provider
            .insert_into(&ctx.session.state(), input, InsertOp::Append)
            .await?;
        let _count_batches = collect(plan, ctx.session.task_ctx()).await?;

        let all_files = ctx.store.list(None).try_collect::<Vec<_>>().await?;

        let unique_write_ids: vortex_utils::aliases::hash_set::HashSet<_> = all_files
            .iter()
            .filter_map(|m| {
                m.location
                    .filename()
                    .and_then(|name| name.split_once('_'))
                    .map(|(prefix, _)| prefix.to_string())
            })
            .collect();

        assert_eq!(
            unique_write_ids.len(),
            1,
            "Expected one write_id for direct insert_into streaming exec, got {:?} from files: {:?}",
            unique_write_ids,
            all_files
                .iter()
                .map(|m| format!("{}: {}B", m.location, m.size))
                .collect::<Vec<_>>()
        );

        let result = ctx
            .session
            .sql("SELECT COUNT(*) AS cnt FROM my_tbl")
            .await?
            .collect()
            .await?;
        let count = result[0]
            .column(0)
            .as_any()
            .downcast_ref::<Int64Array>()
            .unwrap()
            .value(0);
        assert_eq!(count, expected_total_rows, "Total row count mismatch");

        Ok(())
    }

    /// Same as the bounded variant above, but marks the `StreamingTableExec`
    /// input as unbounded; the sink must still produce a single write_id.
    #[tokio::test]
    async fn test_listing_table_direct_insert_into_unbounded_streaming_exec_single_write_id()
    -> anyhow::Result<()> {
        use arrow_schema::SchemaRef;
        use datafusion::physical_plan::ExecutionPlan;
        use datafusion::physical_plan::collect;
        use datafusion::physical_plan::stream::RecordBatchStreamAdapter;
        use datafusion::physical_plan::streaming::PartitionStream;
        use datafusion::physical_plan::streaming::StreamingTableExec;
        use datafusion_expr::dml::InsertOp;
        use futures::stream;

        // A one-batch partition stream used to emulate a streaming source.
        #[derive(Debug)]
        struct StaticPartitionStream {
            schema: SchemaRef,
            batch: RecordBatch,
        }

        impl PartitionStream for StaticPartitionStream {
            fn schema(&self) -> &SchemaRef {
                &self.schema
            }

            fn execute(
                &self,
                _ctx: Arc<datafusion::execution::TaskContext>,
            ) -> datafusion::physical_plan::SendableRecordBatchStream {
                let schema = Arc::clone(&self.schema);
                let batch = self.batch.clone();
                Box::pin(RecordBatchStreamAdapter::new(
                    schema,
                    stream::iter(vec![Ok(batch)]),
                ))
            }
        }

        let ctx = TestSessionContext::default();

        ctx.session
            .sql(
                "CREATE EXTERNAL TABLE my_tbl \
                 (a BIGINT NOT NULL) \
                 STORED AS vortex \
                 LOCATION 'table/' \
                 OPTIONS(target_file_size_mb '64');",
            )
            .await?;

        let table_provider = ctx.session.table_provider("my_tbl").await?;

        let rows_per_partition = 100_000_usize;
        let num_partitions = 8_usize;
        let expected_total_rows = (rows_per_partition * num_partitions) as i64;
        let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int64, false)]));

        let mut partitions: Vec<Arc<dyn PartitionStream>> = Vec::new();
        for p in 0..num_partitions {
            let values = pseudo_random_i64s(rows_per_partition, (p * rows_per_partition) as i64);
            let batch =
                RecordBatch::try_new(schema.clone(), vec![Arc::new(Int64Array::from(values))])?;

            partitions.push(Arc::new(StaticPartitionStream {
                schema: schema.clone(),
                batch,
            }));
        }

        // `true` => unbounded (infinite) streaming input.
        let input = Arc::new(StreamingTableExec::try_new(
            schema,
            partitions,
            None,
            Vec::new(),
            true,
            None,
        )?) as Arc<dyn ExecutionPlan>;

        let plan = table_provider
            .insert_into(&ctx.session.state(), input, InsertOp::Append)
            .await?;
        let _count_batches = collect(plan, ctx.session.task_ctx()).await?;

        let all_files = ctx.store.list(None).try_collect::<Vec<_>>().await?;

        let unique_write_ids: vortex_utils::aliases::hash_set::HashSet<_> = all_files
            .iter()
            .filter_map(|m| {
                m.location
                    .filename()
                    .and_then(|name| name.split_once('_'))
                    .map(|(prefix, _)| prefix.to_string())
            })
            .collect();

        assert_eq!(
            unique_write_ids.len(),
            1,
            "Expected one write_id for unbounded streaming insert, got {:?} from files: {:?}",
            unique_write_ids,
            all_files
                .iter()
                .map(|m| format!("{}: {}B", m.location, m.size))
                .collect::<Vec<_>>()
        );

        let result = ctx
            .session
            .sql("SELECT COUNT(*) AS cnt FROM my_tbl")
            .await?
            .collect()
            .await?;
        let count = result[0]
            .column(0)
            .as_any()
            .downcast_ref::<Int64Array>()
            .unwrap()
            .value(0);
        assert_eq!(count, expected_total_rows, "Total row count mismatch");

        Ok(())
    }
}