Skip to content

Commit 8d65084

Browse files
authored
feat: rollback to snapshot action (#8)
1 parent f3b98e0 commit 8d65084

2 files changed

Lines changed: 334 additions & 0 deletions

File tree

crates/iceberg/src/transaction/mod.rs

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@ mod action;
5454

5555
pub use action::*;
5656
mod append;
57+
mod rollback_to_snapshot;
5758
mod snapshot;
5859
mod sort_order;
5960
mod update_location;
@@ -71,6 +72,7 @@ use crate::spec::TableProperties;
7172
use crate::table::Table;
7273
use crate::transaction::action::BoxedTransactionAction;
7374
use crate::transaction::append::FastAppendAction;
75+
use crate::transaction::rollback_to_snapshot::RollbackToSnapshotAction;
7476
use crate::transaction::sort_order::ReplaceSortOrderAction;
7577
use crate::transaction::update_location::UpdateLocationAction;
7678
use crate::transaction::update_properties::UpdatePropertiesAction;
@@ -146,6 +148,11 @@ impl Transaction {
146148
ReplaceSortOrderAction::new()
147149
}
148150

151+
/// Creates rollback to snapshot action.
152+
pub fn rollback_to_snapshot(&self) -> RollbackToSnapshotAction {
153+
RollbackToSnapshotAction::new()
154+
}
155+
149156
/// Set the location of table
150157
pub fn update_location(&self) -> UpdateLocationAction {
151158
UpdateLocationAction::new()
Lines changed: 327 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,327 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
18+
use std::sync::Arc;
19+
20+
use async_trait::async_trait;
21+
22+
use crate::error::Result;
23+
use crate::spec::{MAIN_BRANCH, SnapshotReference, SnapshotRetention};
24+
use crate::table::Table;
25+
use crate::transaction::action::{ActionCommit, TransactionAction};
26+
use crate::{Error, ErrorKind, TableRequirement, TableUpdate};
27+
28+
/// Transaction action that rolls back the table to a specific snapshot.
29+
#[derive(Default)]
30+
pub struct RollbackToSnapshotAction {
31+
snapshot_id: Option<i64>,
32+
}
33+
34+
impl RollbackToSnapshotAction {
35+
/// Creates a new [`RollbackToSnapshotAction`].
36+
pub fn new() -> Self {
37+
Self::default()
38+
}
39+
40+
/// Sets the target snapshot_id for this action.
41+
pub fn set_snapshot_id(mut self, snapshot_id: i64) -> Self {
42+
self.snapshot_id = Some(snapshot_id);
43+
self
44+
}
45+
}
46+
47+
#[async_trait]
48+
impl TransactionAction for RollbackToSnapshotAction {
49+
async fn commit(self: Arc<Self>, table: &Table) -> Result<ActionCommit> {
50+
let Some(snapshot_id) = self.snapshot_id else {
51+
return Err(Error::new(ErrorKind::DataInvalid, "snapshot id is not set"));
52+
};
53+
54+
table
55+
.metadata()
56+
.snapshots()
57+
.find(|s| s.snapshot_id() == snapshot_id)
58+
.ok_or_else(|| {
59+
Error::new(
60+
ErrorKind::DataInvalid,
61+
format!(
62+
"snapshot with id {} does not exist in the table",
63+
snapshot_id
64+
),
65+
)
66+
})?;
67+
68+
let reference =
69+
SnapshotReference::new(snapshot_id, SnapshotRetention::branch(None, None, None));
70+
71+
let updates = vec![TableUpdate::SetSnapshotRef {
72+
ref_name: MAIN_BRANCH.to_string(),
73+
reference,
74+
}];
75+
76+
let requirements = vec![
77+
TableRequirement::UuidMatch {
78+
uuid: table.metadata().uuid(),
79+
},
80+
TableRequirement::RefSnapshotIdMatch {
81+
r#ref: MAIN_BRANCH.to_string(),
82+
snapshot_id: table.metadata().current_snapshot_id(),
83+
},
84+
];
85+
86+
Ok(ActionCommit::new(updates, requirements))
87+
}
88+
}
89+
90+
#[cfg(test)]
91+
mod tests {
92+
use std::collections::HashMap;
93+
use std::sync::{Arc, LazyLock};
94+
95+
use arrow_array::cast::AsArray;
96+
use arrow_array::types::Int32Type;
97+
use arrow_array::{RecordBatch, record_batch};
98+
use futures::TryStreamExt;
99+
use itertools::Itertools;
100+
use uuid::Uuid;
101+
102+
use crate::arrow::schema_to_arrow_schema;
103+
use crate::memory::tests::new_memory_catalog;
104+
use crate::spec::{
105+
DataContentType, DataFileBuilder, DataFileFormat, Literal, MAIN_BRANCH, NestedField,
106+
PrimitiveType, Schema as IcebergSchema, SnapshotReference, SnapshotRetention, Struct, Type,
107+
};
108+
use crate::table::Table;
109+
use crate::transaction::tests::make_v3_minimal_table_in_catalog;
110+
use crate::transaction::{ApplyTransactionAction, Transaction, TransactionAction};
111+
use crate::writer::base_writer::data_file_writer::DataFileWriterBuilder;
112+
use crate::writer::file_writer::ParquetWriterBuilder;
113+
use crate::writer::file_writer::location_generator::{
114+
DefaultFileNameGenerator, DefaultLocationGenerator,
115+
};
116+
use crate::writer::file_writer::rolling_writer::RollingFileWriterBuilder;
117+
use crate::writer::{IcebergWriter, IcebergWriterBuilder};
118+
use crate::{
119+
Catalog, NamespaceIdent, TableCreation, TableIdent, TableRequirement, TableUpdate,
120+
};
121+
122+
static FILE_NAME_GENERATOR: LazyLock<DefaultFileNameGenerator> = LazyLock::new(|| {
123+
DefaultFileNameGenerator::new("test".to_string(), None, DataFileFormat::Parquet)
124+
});
125+
126+
async fn write_and_commit(table: &Table, catalog: &dyn Catalog, batch: RecordBatch) -> Table {
127+
let iceberg_schema = table.metadata().current_schema();
128+
let arrow_schema = schema_to_arrow_schema(iceberg_schema).unwrap();
129+
let batch = batch.with_schema(Arc::new(arrow_schema)).unwrap();
130+
131+
let location_generator = DefaultLocationGenerator::new(table.metadata().clone()).unwrap();
132+
let parquet_writer_builder = ParquetWriterBuilder::new(
133+
parquet::file::properties::WriterProperties::default(),
134+
table.metadata().current_schema().clone(),
135+
);
136+
let rolling_file_writer_builder = RollingFileWriterBuilder::new_with_default_file_size(
137+
parquet_writer_builder,
138+
table.file_io().clone(),
139+
location_generator.clone(),
140+
FILE_NAME_GENERATOR.clone(),
141+
);
142+
let data_file_writer_builder = DataFileWriterBuilder::new(rolling_file_writer_builder);
143+
let mut data_file_writer = data_file_writer_builder.build(None).await.unwrap();
144+
data_file_writer.write(batch).await.unwrap();
145+
let data_file = data_file_writer.close().await.unwrap();
146+
147+
let tx = Transaction::new(table);
148+
let append_action = tx.fast_append().add_data_files(data_file);
149+
let tx = append_action.apply(tx).unwrap();
150+
tx.commit(catalog).await.unwrap()
151+
}
152+
153+
async fn get_batches(table: &Table) -> Vec<RecordBatch> {
154+
let batch_stream = table
155+
.scan()
156+
.select_all()
157+
.build()
158+
.unwrap()
159+
.to_arrow()
160+
.await
161+
.unwrap();
162+
batch_stream.try_collect().await.unwrap()
163+
}
164+
165+
#[tokio::test]
166+
async fn test_rollback_to_snapshot() {
167+
let catalog = new_memory_catalog().await;
168+
let namespace_ident = NamespaceIdent::new(format!("ns-{}", Uuid::new_v4()));
169+
let table_ident =
170+
TableIdent::new(namespace_ident.clone(), format!("table-{}", Uuid::new_v4()));
171+
172+
let schema = IcebergSchema::builder()
173+
.with_schema_id(1)
174+
.with_fields(vec![
175+
NestedField::optional(0, "id", Type::Primitive(PrimitiveType::Int)).into(),
176+
])
177+
.build()
178+
.unwrap();
179+
180+
let table_creation = TableCreation::builder()
181+
.name(table_ident.name.clone())
182+
.schema(schema)
183+
.build();
184+
185+
catalog
186+
.create_namespace(&namespace_ident, HashMap::new())
187+
.await
188+
.unwrap();
189+
190+
let table = catalog
191+
.create_table(&namespace_ident, table_creation)
192+
.await
193+
.unwrap();
194+
195+
let get_id_columns = |batches: &[RecordBatch]| {
196+
batches
197+
.iter()
198+
.flat_map(|b| {
199+
b.columns()
200+
.iter()
201+
.flat_map(|c| c.as_primitive::<Int32Type>().values())
202+
.copied()
203+
})
204+
.sorted()
205+
.collect::<Vec<_>>()
206+
};
207+
208+
let insert_batch = record_batch!(("id", Int32, [1, 2])).unwrap();
209+
let table = write_and_commit(&table, &catalog, insert_batch).await;
210+
let snapshot_id_1 = table.metadata().current_snapshot_id().unwrap();
211+
let batch_1 = get_batches(&table).await;
212+
let ids = get_id_columns(&batch_1);
213+
assert_eq!(ids, [1, 2]);
214+
215+
let insert_batch = record_batch!(("id", Int32, [3, 4])).unwrap();
216+
let table = write_and_commit(&table, &catalog, insert_batch).await;
217+
let snapshot_id_2 = table.metadata().current_snapshot_id().unwrap();
218+
let batch_2 = get_batches(&table).await;
219+
let ids = get_id_columns(&batch_2);
220+
assert_eq!(ids, [1, 2, 3, 4]);
221+
222+
let tx = Transaction::new(&table);
223+
let action = tx
224+
.rollback_to_snapshot()
225+
.set_snapshot_id(snapshot_id_1)
226+
.apply(tx)
227+
.unwrap();
228+
let table = action.commit(&catalog).await.unwrap();
229+
assert_eq!(table.metadata().current_snapshot_id(), Some(snapshot_id_1));
230+
231+
let batch_after_rollback = get_batches(&table).await;
232+
let ids = get_id_columns(&batch_after_rollback);
233+
assert_eq!(ids, [1, 2]);
234+
235+
let insert_batch = record_batch!(("id", Int32, [5, 6])).unwrap();
236+
let table = write_and_commit(&table, &catalog, insert_batch).await;
237+
let snapshot_id_3 = table.metadata().current_snapshot_id().unwrap();
238+
assert_ne!(snapshot_id_3, snapshot_id_2);
239+
240+
let batch_3 = get_batches(&table).await;
241+
let ids = get_id_columns(&batch_3);
242+
assert_eq!(ids, [1, 2, 5, 6]);
243+
244+
let tx = Transaction::new(&table);
245+
let action = tx
246+
.rollback_to_snapshot()
247+
.set_snapshot_id(snapshot_id_2)
248+
.apply(tx)
249+
.unwrap();
250+
let table = action.commit(&catalog).await.unwrap();
251+
assert_eq!(table.metadata().current_snapshot_id(), Some(snapshot_id_2));
252+
253+
let batch_after_rollback = get_batches(&table).await;
254+
let ids = get_id_columns(&batch_after_rollback);
255+
assert_eq!(ids, [1, 2, 3, 4]);
256+
257+
let tx = Transaction::new(&table);
258+
let action = tx
259+
.rollback_to_snapshot()
260+
.set_snapshot_id(snapshot_id_3)
261+
.apply(tx)
262+
.unwrap();
263+
let table = action.commit(&catalog).await.unwrap();
264+
assert_eq!(table.metadata().current_snapshot_id(), Some(snapshot_id_3));
265+
266+
let batch_after_rollback = get_batches(&table).await;
267+
let ids = get_id_columns(&batch_after_rollback);
268+
assert_eq!(ids, [1, 2, 5, 6]);
269+
}
270+
271+
async fn insert_data(catalog: &dyn Catalog, table_ident: &TableIdent) -> Table {
272+
let table = catalog.load_table(table_ident).await.unwrap();
273+
let data_file = DataFileBuilder::default()
274+
.content(DataContentType::Data)
275+
.file_path(format!("test/{}.parquet", Uuid::new_v4()))
276+
.file_format(DataFileFormat::Parquet)
277+
.file_size_in_bytes(100)
278+
.record_count(1)
279+
.partition_spec_id(table.metadata().default_partition_spec_id())
280+
.partition(Struct::from_iter([Some(Literal::long(100))]))
281+
.build()
282+
.unwrap();
283+
284+
let tx = Transaction::new(&table);
285+
let tx = tx
286+
.fast_append()
287+
.add_data_files(vec![data_file])
288+
.apply(tx)
289+
.unwrap();
290+
291+
tx.commit(catalog).await.unwrap()
292+
}
293+
294+
#[tokio::test]
295+
async fn test_rollback_to_snapshot_build() {
296+
let catalog = new_memory_catalog().await;
297+
let table = make_v3_minimal_table_in_catalog(&catalog).await;
298+
let table = insert_data(&catalog, table.identifier()).await;
299+
let snapshot_id = table.metadata().current_snapshot().unwrap().snapshot_id();
300+
301+
let tx = Transaction::new(&table);
302+
let action = tx.rollback_to_snapshot().set_snapshot_id(snapshot_id);
303+
assert_eq!(action.snapshot_id, Some(snapshot_id));
304+
305+
let mut action_commit = Arc::new(action).commit(&table).await.unwrap();
306+
let updates = action_commit.take_updates();
307+
let requirements = action_commit.take_requirements();
308+
309+
let reference =
310+
SnapshotReference::new(snapshot_id, SnapshotRetention::branch(None, None, None));
311+
312+
assert_eq!(updates, vec![TableUpdate::SetSnapshotRef {
313+
ref_name: MAIN_BRANCH.to_string(),
314+
reference,
315+
}],);
316+
317+
assert_eq!(requirements, vec![
318+
TableRequirement::UuidMatch {
319+
uuid: table.metadata().uuid(),
320+
},
321+
TableRequirement::RefSnapshotIdMatch {
322+
r#ref: MAIN_BRANCH.to_string(),
323+
snapshot_id: table.metadata().current_snapshot_id(),
324+
},
325+
])
326+
}
327+
}

0 commit comments

Comments
 (0)