Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions core/src/duckdb/creator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -232,6 +232,24 @@ impl TableDefinition {
.map(|(name, time_created)| (RelationName(name), time_created))
.collect())
}

/// Determine which table name DML statements (DELETE, UPDATE) should target.
///
/// View-backed tables store their rows in internal `__data_*` tables, so
/// DML must be issued against the newest internal table rather than the
/// view itself. When no internal table exists, the base definition name
/// is the correct target.
///
/// # Errors
///
/// Returns an error if listing the internal tables fails.
pub fn resolve_dml_table_name(&self, tx: &Transaction<'_>) -> super::Result<String> {
    // The internal-table listing is ordered; the last entry is the newest.
    let target = match self.list_internal_tables(tx)?.last() {
        Some((latest_internal, _created)) => latest_internal.to_string(),
        None => self.name.to_string(),
    };
    Ok(target)
}
}

/// A table creator, which is used to create, delete, and manage tables based on a `TableDefinition`.
Expand Down
166 changes: 152 additions & 14 deletions core/src/duckdb/write.rs
Original file line number Diff line number Diff line change
Expand Up @@ -260,12 +260,12 @@ impl TableProvider for DuckDBTableWriter {
} else {
Some(filters_to_sql(&filters, Some(expr::Engine::DuckDB))?)
};
let table_name = self.table_definition.name().to_string();
let table_definition = Arc::clone(&self.table_definition);
let pool = Arc::clone(&self.pool);

Ok(Arc::new(DeletionExec::new(Arc::new(DuckDBDeletionSink {
pool,
table_name,
table_definition,
sql_where,
}))))
}
Expand All @@ -281,34 +281,34 @@ impl TableProvider for DuckDBTableWriter {
}

let set_clause = assignments_to_sql(&assignments, Some(expr::Engine::DuckDB))?;
let table_name = self.table_definition.name().to_string();
let pool = Arc::clone(&self.pool);

let sql = if filters.is_empty() {
format!(r#"UPDATE "{table_name}" SET {set_clause}"#)
let sql_where = if filters.is_empty() {
None
} else {
let sql_where = filters_to_sql(&filters, Some(expr::Engine::DuckDB))?;
format!(r#"UPDATE "{table_name}" SET {set_clause} WHERE {sql_where}"#)
Some(filters_to_sql(&filters, Some(expr::Engine::DuckDB))?)
};
let table_definition = Arc::clone(&self.table_definition);
let pool = Arc::clone(&self.pool);

Ok(Arc::new(UpdateExec::new(Arc::new(DuckDBUpdateSink {
pool,
sql,
table_definition,
set_clause,
sql_where,
}))))
}
}

/// Sink that executes a DELETE against a DuckDB table when the deletion
/// plan is run.
struct DuckDBDeletionSink {
    // Connection pool used to open a synchronous connection at execute time.
    pool: Arc<DuckDbConnectionPool>,
    // Full table definition (not a pre-resolved name) so the actual DML
    // target table can be resolved inside the transaction at execution time.
    table_definition: Arc<TableDefinition>,
    // Pre-rendered SQL WHERE clause; `None` means delete all rows.
    sql_where: Option<String>,
}

#[async_trait]
impl DeletionSink for DuckDBDeletionSink {
async fn delete_from(&self) -> Result<u64, Box<dyn std::error::Error + Send + Sync>> {
let pool = Arc::clone(&self.pool);
let table_name = self.table_name.clone();
let table_definition = Arc::clone(&self.table_definition);
let sql_where = self.sql_where.clone();

tokio::task::spawn_blocking(
Expand All @@ -317,6 +317,8 @@ impl DeletionSink for DuckDBDeletionSink {
let duckdb_conn = DuckDB::duckdb_conn(&mut db_conn)?;
let tx = duckdb_conn.conn.transaction()?;

let table_name = table_definition.resolve_dml_table_name(&tx)?;

let delete_sql = if let Some(sql_where) = &sql_where {
format!(r#"DELETE FROM "{table_name}" WHERE {sql_where}"#)
} else {
Expand All @@ -335,21 +337,32 @@ impl DeletionSink for DuckDBDeletionSink {

/// Sink that executes an UPDATE against a DuckDB table when the update
/// plan is run.
struct DuckDBUpdateSink {
    // Connection pool used to open a synchronous connection at execute time.
    pool: Arc<DuckDbConnectionPool>,
    // Full table definition (not a pre-resolved name) so the actual DML
    // target table can be resolved inside the transaction at execution time.
    table_definition: Arc<TableDefinition>,
    // Pre-rendered `SET ...` clause of the UPDATE statement.
    set_clause: String,
    // Pre-rendered SQL WHERE clause; `None` means update all rows.
    sql_where: Option<String>,
}

#[async_trait]
impl UpdateSink for DuckDBUpdateSink {
async fn execute_update(&self) -> Result<u64, Box<dyn std::error::Error + Send + Sync>> {
let pool = Arc::clone(&self.pool);
let sql = self.sql.clone();
let table_definition = Arc::clone(&self.table_definition);
let set_clause = self.set_clause.clone();
let sql_where = self.sql_where.clone();

tokio::task::spawn_blocking(
move || -> Result<u64, Box<dyn std::error::Error + Send + Sync>> {
let mut db_conn = pool.connect_sync()?;
let duckdb_conn = DuckDB::duckdb_conn(&mut db_conn)?;
let tx = duckdb_conn.conn.transaction()?;

let table_name = table_definition.resolve_dml_table_name(&tx)?;

let sql = if let Some(sql_where) = &sql_where {
format!(r#"UPDATE "{table_name}" SET {set_clause} WHERE {sql_where}"#)
} else {
format!(r#"UPDATE "{table_name}" SET {set_clause}"#)
};
let count = tx.execute(&sql, [])?;

tx.commit()?;
Expand Down Expand Up @@ -1871,4 +1884,129 @@ mod test {
assert_eq!(name, "all");
}
}

/// Test helper: seed a DuckDB table through an `Overwrite` insert — which
/// materializes an internal `__data_*` table fronted by a view named after
/// the table definition — then build a `DuckDBTableWriter` over it.
/// Returns the writer together with the shared connection pool.
async fn setup_writer_with_overwrite_data(
    ids: Vec<i64>,
    names: Vec<&str>,
) -> (DuckDBTableWriter, Arc<DuckDbConnectionPool>) {
    let pool = get_mem_duckdb();
    let table_definition = get_basic_table_definition();
    let schema = table_definition.schema();

    // Build the seed batch from the caller-supplied columns.
    let seed_batch = RecordBatch::try_new(
        Arc::clone(&schema),
        vec![
            Arc::new(Int64Array::from(ids)),
            Arc::new(StringArray::from(names)),
        ],
    )
    .expect("should create a record batch");

    // Write through the data sink in Overwrite mode; this is what creates
    // the view-over-internal-table layout under test.
    let sink = DuckDBDataSink::new(
        Arc::clone(&pool),
        Arc::clone(&table_definition),
        InsertOp::Overwrite,
        None,
        Arc::clone(&schema),
    );
    let stream = Box::pin(
        MemoryStream::try_new(vec![seed_batch], Arc::clone(&schema), None)
            .expect("to get stream"),
    );
    Arc::new(sink)
        .write_all(stream, &Arc::new(TaskContext::default()))
        .await
        .expect("to write all");

    // The writer requires a read provider; an empty in-memory table suffices
    // because these tests only exercise the write paths.
    let read_provider: Arc<dyn TableProvider> = Arc::new(
        datafusion::datasource::MemTable::try_new(schema, vec![vec![]])
            .expect("to create mem table"),
    );
    let writer = DuckDBTableWriterBuilder::new()
        .with_read_provider(read_provider)
        .with_pool(Arc::clone(&pool))
        .with_table_definition((*table_definition).clone())
        .build()
        .expect("to build writer");

    (writer, pool)
}

#[tokio::test]
async fn test_delete_from_view_backed_table_with_filter() {
    let _guard = init_tracing(None);
    let (writer, pool) =
        setup_writer_with_overwrite_data(vec![1, 2, 3], vec!["a", "b", "c"]).await;

    let ctx = datafusion::prelude::SessionContext::new();

    // Issue DELETE ... WHERE id = 2 through the writer.
    let filters =
        vec![datafusion::logical_expr::col("id").eq(datafusion::logical_expr::lit(2i64))];
    let plan = writer
        .delete_from(&ctx.state(), filters)
        .await
        .expect("delete_from should succeed");

    let deleted = extract_count(plan).await;
    assert_eq!(deleted, 1, "should have deleted exactly 1 row");

    // Rows 1 and 3 must survive, readable through the view.
    let (ids, names) = query_all_rows(&pool);
    assert_eq!(ids.len(), 2, "should have 2 rows remaining");
    assert_eq!(ids, vec![1, 3]);
    assert_eq!(names, vec!["a", "c"]);
}

#[tokio::test]
async fn test_delete_from_view_backed_table_empty_filters() {
    let _guard = init_tracing(None);
    let (writer, pool) =
        setup_writer_with_overwrite_data(vec![1, 2, 3], vec!["a", "b", "c"]).await;

    // An empty filter list means an unconditional DELETE.
    let ctx = datafusion::prelude::SessionContext::new();
    let plan = writer
        .delete_from(&ctx.state(), Vec::new())
        .await
        .expect("delete_from should succeed");

    assert_eq!(extract_count(plan).await, 3, "should have deleted all 3 rows");

    let (ids, _names) = query_all_rows(&pool);
    assert!(ids.is_empty(), "table should be empty after delete-all");
}

#[tokio::test]
async fn test_update_view_backed_table_with_filter() {
    let _guard = init_tracing(None);
    let (writer, pool) =
        setup_writer_with_overwrite_data(vec![1, 2, 3], vec!["a", "b", "c"]).await;

    let ctx = datafusion::prelude::SessionContext::new();

    // Issue UPDATE name = 'updated' WHERE id = 2 through the writer.
    let assignments = vec![("name".to_string(), datafusion::logical_expr::lit("updated"))];
    let filters =
        vec![datafusion::logical_expr::col("id").eq(datafusion::logical_expr::lit(2i64))];
    let plan = writer
        .update(&ctx.state(), assignments, filters)
        .await
        .expect("update should succeed");

    let updated = extract_count(plan).await;
    assert_eq!(updated, 1, "should have updated exactly 1 row");

    // Only row 2's name changes; everything else is untouched.
    let (ids, names) = query_all_rows(&pool);
    assert_eq!(ids.len(), 3, "should still have 3 rows");
    for (id, name) in ids.iter().zip(names.iter()) {
        let expected = match *id {
            1 => "a",
            2 => "updated",
            3 => "c",
            other => panic!("unexpected id {other}"),
        };
        assert_eq!(name, expected);
    }
}
}
Loading