@@ -329,6 +329,14 @@ def update_schema(self) -> UpdateSchema:
329329 """
330330 return UpdateSchema (self ._table , self )
331331
332+ def update_snapshot (self ) -> UpdateSnapshot :
333+ """Create a new UpdateSnapshot to produce a new snapshot for the table.
334+
335+ Returns:
336+ A new UpdateSnapshot
337+ """
338+ return UpdateSnapshot (self ._table , self )
339+
332340 def remove_properties (self , * removals : str ) -> Transaction :
333341 """Remove properties.
334342
@@ -351,6 +359,12 @@ def update_location(self, location: str) -> Transaction:
351359 """
352360 raise NotImplementedError ("Not yet implemented" )
353361
362+ def schema (self ) -> Schema :
363+ try :
364+ return next (update for update in self ._updates if isinstance (update , AddSchemaUpdate )).schema_
365+ except StopIteration :
366+ return self ._table .schema ()
367+
354368 def commit_transaction (self ) -> Table :
355369 """Commit the changes to the catalog.
356370
@@ -965,8 +979,21 @@ def history(self) -> List[SnapshotLogEntry]:
965979 return self .metadata .snapshot_log
966980
967981 def update_schema (self , allow_incompatible_changes : bool = False , case_sensitive : bool = True ) -> UpdateSchema :
982+ """Create a new UpdateSchema to alter the columns of this table.
983+
984+ Returns:
985+ A new UpdateSchema.
986+ """
968987 return UpdateSchema (self , allow_incompatible_changes = allow_incompatible_changes , case_sensitive = case_sensitive )
969988
989+ def update_snapshot (self ) -> UpdateSnapshot :
990+ """Create a new UpdateSnapshot to produce a new snapshot for the table.
991+
992+ Returns:
993+ A new UpdateSnapshot
994+ """
995+ return UpdateSnapshot (self )
996+
970997 def name_mapping (self ) -> NameMapping :
971998 """Return the table's field-id NameMapping."""
972999 if name_mapping_json := self .properties .get (TableProperties .DEFAULT_NAME_MAPPING ):
@@ -976,7 +1003,7 @@ def name_mapping(self) -> NameMapping:
9761003
9771004 def append (self , df : pa .Table ) -> None :
9781005 """
979- Append data to the table.
1006+ Shorthand API for appending a PyArrow table to the table.
9801007
9811008 Args:
9821009 df: The Arrow dataframe that will be appended to overwrite the table
@@ -992,19 +1019,16 @@ def append(self, df: pa.Table) -> None:
9921019 if len (self .spec ().fields ) > 0 :
9931020 raise ValueError ("Cannot write to partitioned tables" )
9941021
995- merge = _MergingSnapshotProducer (operation = Operation .APPEND , table = self )
996-
997- # skip writing data files if the dataframe is empty
998- if df .shape [0 ] > 0 :
999- data_files = _dataframe_to_data_files (self , df = df )
1000- for data_file in data_files :
1001- merge .append_data_file (data_file )
1002-
1003- merge .commit ()
1022+ with self .update_snapshot ().fast_append () as update_snapshot :
1023+ # skip writing data files if the dataframe is empty
1024+ if df .shape [0 ] > 0 :
1025+ data_files = _dataframe_to_data_files (self , df = df )
1026+ for data_file in data_files :
1027+ update_snapshot .append_data_file (data_file )
10041028
10051029 def overwrite (self , df : pa .Table , overwrite_filter : BooleanExpression = ALWAYS_TRUE ) -> None :
10061030 """
1007- Overwrite all the data in the table.
1031+ Shorthand for overwriting the table with a PyArrow table.
10081032
10091033 Args:
10101034 df: The Arrow dataframe that will be used to overwrite the table
@@ -1025,18 +1049,12 @@ def overwrite(self, df: pa.Table, overwrite_filter: BooleanExpression = ALWAYS_T
10251049 if len (self .spec ().fields ) > 0 :
10261050 raise ValueError ("Cannot write to partitioned tables" )
10271051
1028- merge = _MergingSnapshotProducer (
1029- operation = Operation .OVERWRITE if self .current_snapshot () is not None else Operation .APPEND ,
1030- table = self ,
1031- )
1032-
1033- # skip writing data files if the dataframe is empty
1034- if df .shape [0 ] > 0 :
1035- data_files = _dataframe_to_data_files (self , df = df )
1036- for data_file in data_files :
1037- merge .append_data_file (data_file )
1038-
1039- merge .commit ()
1052+ with self .update_snapshot ().overwrite () as update_snapshot :
1053+ # skip writing data files if the dataframe is empty
1054+ if df .shape [0 ] > 0 :
1055+ data_files = _dataframe_to_data_files (self , df = df )
1056+ for data_file in data_files :
1057+ update_snapshot .append_data_file (data_file )
10401058
10411059 def refs (self ) -> Dict [str , SnapshotRef ]:
10421060 """Return the snapshot references in the table."""
@@ -2331,7 +2349,12 @@ def _generate_manifest_list_path(location: str, snapshot_id: int, attempt: int,
23312349 return f'{ location } /metadata/snap-{ snapshot_id } -{ attempt } -{ commit_uuid } .avro'
23322350
23332351
2334- def _dataframe_to_data_files (table : Table , df : pa .Table ) -> Iterable [DataFile ]:
2352+ def _dataframe_to_data_files (table : Table , df : pa .Table , file_schema : Optional [Schema ] = None ) -> Iterable [DataFile ]:
2353+ """Convert a PyArrow table into a DataFile.
2354+
2355+ Returns:
2356+ An iterable that supplies datafiles that represent the table.
2357+ """
23352358 from pyiceberg .io .pyarrow import write_file
23362359
23372360 if len (table .spec ().fields ) > 0 :
@@ -2342,7 +2365,7 @@ def _dataframe_to_data_files(table: Table, df: pa.Table) -> Iterable[DataFile]:
23422365
23432366 # This is an iter, so we don't have to materialize everything every time
23442367 # This will be more relevant when we start doing partitioned writes
2345- yield from write_file (table , iter ([WriteTask (write_uuid , next (counter ), df )]))
2368+ yield from write_file (table , iter ([WriteTask (write_uuid , next (counter ), df )]), file_schema = file_schema )
23462369
23472370
23482371class _MergingSnapshotProducer :
@@ -2352,55 +2375,35 @@ class _MergingSnapshotProducer:
23522375 _parent_snapshot_id : Optional [int ]
23532376 _added_data_files : List [DataFile ]
23542377 _commit_uuid : uuid .UUID
2378+ _transaction : Optional [Transaction ]
23552379
2356- def __init__ (self , operation : Operation , table : Table ) -> None :
2380+ def __init__ (self , operation : Operation , table : Table , transaction : Optional [ Transaction ] = None ) -> None :
23572381 self ._operation = operation
23582382 self ._table = table
23592383 self ._snapshot_id = table .new_snapshot_id ()
23602384 # Since we only support the main branch for now
23612385 self ._parent_snapshot_id = snapshot .snapshot_id if (snapshot := self ._table .current_snapshot ()) else None
23622386 self ._added_data_files = []
23632387 self ._commit_uuid = uuid .uuid4 ()
2388+ self ._transaction = transaction
2389+
2390+ def __enter__ (self ) -> _MergingSnapshotProducer :
2391+ """Start a transaction to update the table."""
2392+ return self
2393+
2394+ def __exit__ (self , _ : Any , value : Any , traceback : Any ) -> None :
2395+ """Close and commit the transaction."""
2396+ self .commit ()
23642397
23652398 def append_data_file (self , data_file : DataFile ) -> _MergingSnapshotProducer :
23662399 self ._added_data_files .append (data_file )
23672400 return self
23682401
2369- def _deleted_entries (self ) -> List [ManifestEntry ]:
2370- """To determine if we need to record any deleted entries.
2371-
2372- With partial overwrites we have to use the predicate to evaluate
2373- which entries are affected.
2374- """
2375- if self ._operation == Operation .OVERWRITE :
2376- if self ._parent_snapshot_id is not None :
2377- previous_snapshot = self ._table .snapshot_by_id (self ._parent_snapshot_id )
2378- if previous_snapshot is None :
2379- # This should never happen since you cannot overwrite an empty table
2380- raise ValueError (f"Could not find the previous snapshot: { self ._parent_snapshot_id } " )
2381-
2382- executor = ExecutorFactory .get_or_create ()
2383-
2384- def _get_entries (manifest : ManifestFile ) -> List [ManifestEntry ]:
2385- return [
2386- ManifestEntry (
2387- status = ManifestEntryStatus .DELETED ,
2388- snapshot_id = entry .snapshot_id ,
2389- data_sequence_number = entry .data_sequence_number ,
2390- file_sequence_number = entry .file_sequence_number ,
2391- data_file = entry .data_file ,
2392- )
2393- for entry in manifest .fetch_manifest_entry (self ._table .io , discard_deleted = True )
2394- if entry .data_file .content == DataFileContent .DATA
2395- ]
2402+ @abstractmethod
2403+ def _deleted_entries (self ) -> List [ManifestEntry ]: ...
23962404
2397- list_of_entries = executor .map (_get_entries , previous_snapshot .manifests (self ._table .io ))
2398- return list (chain (* list_of_entries ))
2399- return []
2400- elif self ._operation == Operation .APPEND :
2401- return []
2402- else :
2403- raise ValueError (f"Not implemented for: { self ._operation } " )
2405+ @abstractmethod
2406+ def _existing_manifests (self ) -> List [ManifestFile ]: ...
24042407
24052408 def _manifests (self ) -> List [ManifestFile ]:
24062409 def _write_added_manifest () -> List [ManifestFile ]:
@@ -2430,7 +2433,7 @@ def _write_added_manifest() -> List[ManifestFile]:
24302433 def _write_delete_manifest () -> List [ManifestFile ]:
24312434 # Check if we need to mark the files as deleted
24322435 deleted_entries = self ._deleted_entries ()
2433- if deleted_entries :
2436+ if len ( deleted_entries ) > 0 :
24342437 output_file_location = _new_manifest_path (location = self ._table .location (), num = 1 , commit_uuid = self ._commit_uuid )
24352438 with write_manifest (
24362439 format_version = self ._table .format_version ,
@@ -2445,32 +2448,11 @@ def _write_delete_manifest() -> List[ManifestFile]:
24452448 else :
24462449 return []
24472450
2448- def _fetch_existing_manifests () -> List [ManifestFile ]:
2449- existing_manifests = []
2450-
2451- # Add existing manifests
2452- if self ._operation == Operation .APPEND and self ._parent_snapshot_id is not None :
2453- # In case we want to append, just add the existing manifests
2454- previous_snapshot = self ._table .snapshot_by_id (self ._parent_snapshot_id )
2455-
2456- if previous_snapshot is None :
2457- raise ValueError (f"Snapshot could not be found: { self ._parent_snapshot_id } " )
2458-
2459- for manifest in previous_snapshot .manifests (io = self ._table .io ):
2460- if (
2461- manifest .has_added_files ()
2462- or manifest .has_existing_files ()
2463- or manifest .added_snapshot_id == self ._snapshot_id
2464- ):
2465- existing_manifests .append (manifest )
2466-
2467- return existing_manifests
2468-
24692451 executor = ExecutorFactory .get_or_create ()
24702452
24712453 added_manifests = executor .submit (_write_added_manifest )
24722454 delete_manifests = executor .submit (_write_delete_manifest )
2473- existing_manifests = executor .submit (_fetch_existing_manifests )
2455+ existing_manifests = executor .submit (self . _existing_manifests )
24742456
24752457 return added_manifests .result () + delete_manifests .result () + existing_manifests .result ()
24762458
@@ -2515,10 +2497,107 @@ def commit(self) -> Snapshot:
25152497 schema_id = self ._table .schema ().schema_id ,
25162498 )
25172499
2518- with self ._table . transaction () as tx :
2519- tx .add_snapshot (snapshot = snapshot )
2520- tx .set_ref_snapshot (
2500+ if self ._transaction is not None :
2501+ self . _transaction .add_snapshot (snapshot = snapshot )
2502+ self . _transaction .set_ref_snapshot (
25212503 snapshot_id = self ._snapshot_id , parent_snapshot_id = self ._parent_snapshot_id , ref_name = "main" , type = "branch"
25222504 )
2505+ else :
2506+ with self ._table .transaction () as tx :
2507+ tx .add_snapshot (snapshot = snapshot )
2508+ tx .set_ref_snapshot (
2509+ snapshot_id = self ._snapshot_id , parent_snapshot_id = self ._parent_snapshot_id , ref_name = "main" , type = "branch"
2510+ )
25232511
25242512 return snapshot
2513+
2514+
2515+ class FastAppendFiles (_MergingSnapshotProducer ):
2516+ def _existing_manifests (self ) -> List [ManifestFile ]:
2517+ """To determine if there are any existing manifest files.
2518+
2519+ A fast append will add another ManifestFile to the ManifestList.
2520+ All the existing manifest files are considered existing.
2521+ """
2522+ existing_manifests = []
2523+
2524+ if self ._parent_snapshot_id is not None :
2525+ previous_snapshot = self ._table .snapshot_by_id (self ._parent_snapshot_id )
2526+
2527+ if previous_snapshot is None :
2528+ raise ValueError (f"Snapshot could not be found: { self ._parent_snapshot_id } " )
2529+
2530+ for manifest in previous_snapshot .manifests (io = self ._table .io ):
2531+ if manifest .has_added_files () or manifest .has_existing_files () or manifest .added_snapshot_id == self ._snapshot_id :
2532+ existing_manifests .append (manifest )
2533+
2534+ return existing_manifests
2535+
2536+ def _deleted_entries (self ) -> List [ManifestEntry ]:
2537+ """To determine if we need to record any deleted manifest entries.
2538+
2539+ In case of an append, nothing is deleted.
2540+ """
2541+ return []
2542+
2543+
2544+ class OverwriteFiles (_MergingSnapshotProducer ):
2545+ def _existing_manifests (self ) -> List [ManifestFile ]:
2546+ """To determine if there are any existing manifest files.
2547+
2548+ In the of a full overwrite, all the previous manifests are
2549+ considered deleted.
2550+ """
2551+ return []
2552+
2553+ def _deleted_entries (self ) -> List [ManifestEntry ]:
2554+ """To determine if we need to record any deleted entries.
2555+
2556+ With a full overwrite all the entries are considered deleted.
2557+ With partial overwrites we have to use the predicate to evaluate
2558+ which entries are affected.
2559+ """
2560+ if self ._parent_snapshot_id is not None :
2561+ previous_snapshot = self ._table .snapshot_by_id (self ._parent_snapshot_id )
2562+ if previous_snapshot is None :
2563+ # This should never happen since you cannot overwrite an empty table
2564+ raise ValueError (f"Could not find the previous snapshot: { self ._parent_snapshot_id } " )
2565+
2566+ executor = ExecutorFactory .get_or_create ()
2567+
2568+ def _get_entries (manifest : ManifestFile ) -> List [ManifestEntry ]:
2569+ return [
2570+ ManifestEntry (
2571+ status = ManifestEntryStatus .DELETED ,
2572+ snapshot_id = entry .snapshot_id ,
2573+ data_sequence_number = entry .data_sequence_number ,
2574+ file_sequence_number = entry .file_sequence_number ,
2575+ data_file = entry .data_file ,
2576+ )
2577+ for entry in manifest .fetch_manifest_entry (self ._table .io , discard_deleted = True )
2578+ if entry .data_file .content == DataFileContent .DATA
2579+ ]
2580+
2581+ list_of_entries = executor .map (_get_entries , previous_snapshot .manifests (self ._table .io ))
2582+ return list (chain (* list_of_entries ))
2583+ else :
2584+ return []
2585+
2586+
2587+ class UpdateSnapshot :
2588+ _table : Table
2589+ _transaction : Optional [Transaction ]
2590+
2591+ def __init__ (self , table : Table , transaction : Optional [Transaction ] = None ) -> None :
2592+ self ._table = table
2593+ self ._transaction = transaction
2594+
2595+ def fast_append (self ) -> FastAppendFiles :
2596+ return FastAppendFiles (table = self ._table , operation = Operation .APPEND , transaction = self ._transaction )
2597+
2598+ def overwrite (self ) -> OverwriteFiles :
2599+ return OverwriteFiles (
2600+ table = self ._table ,
2601+ operation = Operation .OVERWRITE if self ._table .current_snapshot () is not None else Operation .APPEND ,
2602+ transaction = self ._transaction ,
2603+ )
0 commit comments