Skip to content

Commit 1af9a94

Browse files
committed
Build: Refactor ManageSnapshots to ensure thread safety and add unit tests for concurrent operations
1 parent 52d810e commit 1af9a94

File tree

2 files changed

+375
-2
lines changed

2 files changed

+375
-2
lines changed

pyiceberg/table/update/snapshot.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -810,8 +810,10 @@ class ManageSnapshots(UpdateTableMetadata["ManageSnapshots"]):
810810
ms.create_tag(snapshot_id1, "Tag_A").create_tag(snapshot_id2, "Tag_B")
811811
"""
812812

813-
_updates: Tuple[TableUpdate, ...] = ()
814-
_requirements: Tuple[TableRequirement, ...] = ()
813+
def __init__(self, transaction: Transaction) -> None:
    """Initialize the manager with empty pending updates and requirements.

    Args:
        transaction: Passed unchanged to the parent class initializer.
    """
    super().__init__(transaction)
    # Bound on the instance (not declared on the class), so every
    # ManageSnapshots object starts from its own pair of empty tuples.
    self._updates: Tuple[TableUpdate, ...] = ()
    self._requirements: Tuple[TableRequirement, ...] = ()
815817

816818
def _commit(self) -> UpdatesAndRequirements:
817819
"""Apply the pending changes and commit."""
Lines changed: 371 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,371 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
import threading
18+
import uuid
19+
from typing import Any, Dict, List
20+
from unittest.mock import Mock
21+
from uuid import uuid4
22+
23+
from pyiceberg.table.update.snapshot import ManageSnapshots
24+
25+
26+
def test_manage_snapshots_thread_safety_fix() -> None:
    """Test that ManageSnapshots instances have isolated state.

    Two managers are built over independent mock transactions and each performs
    a single ``create_tag`` call. Their pending updates and requirements are then
    checked for object identity, for inequality, and for the exact content that
    the corresponding mock transaction returned.
    """
    # Create two mock transactions (representing different tables)
    transaction1 = Mock()
    transaction2 = Mock()

    # Create two ManageSnapshots instances
    manage1 = ManageSnapshots(transaction1)
    manage2 = ManageSnapshots(transaction2)

    # Freshly initialised managers both hold ``()``, and the interpreter may reuse
    # one object for every empty tuple, so an identity check is only meaningful
    # once each manager holds real content. Each transaction therefore returns a
    # distinct (updates, requirements) pair.
    transaction1._set_ref_snapshot = Mock(return_value=(("update1",), ("req1",)))
    transaction2._set_ref_snapshot = Mock(return_value=(("update2",), ("req2",)))

    # Make different changes to each instance
    manage1.create_tag(snapshot_id=1001, tag_name="tag1")
    manage2.create_tag(snapshot_id=2001, tag_name="tag2")

    # The two managers must not hold the very same tuple objects.
    assert manage1._updates is not manage2._updates, (
        "ManageSnapshots instances are sharing the same _updates tuple - thread safety bug still exists"
    )
    assert manage1._requirements is not manage2._requirements, (
        "ManageSnapshots instances are sharing the same _requirements tuple - thread safety bug still exists"
    )

    # Verify no cross-contamination of updates or requirements
    assert manage1._updates != manage2._updates, "Updates are contaminated between instances"
    assert manage1._requirements != manage2._requirements, "Requirements are contaminated between instances"

    # Verify each instance has exactly the content its own transaction produced
    assert manage1._updates == ("update1",), f"manage1 should have ('update1',), got {manage1._updates}"
    assert manage2._updates == ("update2",), f"manage2 should have ('update2',), got {manage2._updates}"
    assert manage1._requirements == ("req1",), f"manage1 should have ('req1',), got {manage1._requirements}"
    assert manage2._requirements == ("req2",), f"manage2 should have ('req2',), got {manage2._requirements}"
63+
64+
65+
def test_manage_snapshots_concurrent_operations() -> None:
    """Run two independent ManageSnapshots instances on separate threads.

    Each thread builds its own mock transaction and manager, creates one tag and
    stores the manager's pending updates in a shared result mapping. The test
    passes when each slot holds only the update its own transaction returned.
    """
    results: Dict[str, tuple] = {"manage1_updates": (), "manage2_updates": ()}

    def run_worker(result_key: str, update: str, requirement: str, snapshot_id: int, tag_name: str) -> None:
        """Create one tag through a fresh manager and record its pending updates."""
        txn = Mock()
        txn._set_ref_snapshot = Mock(return_value=((update,), (requirement,)))
        manager = ManageSnapshots(txn)
        manager.create_tag(snapshot_id=snapshot_id, tag_name=tag_name)
        results[result_key] = manager._updates

    workers = [
        threading.Thread(target=run_worker, args=("manage1_updates", "update1", "req1", 1001, "tag1")),
        threading.Thread(target=run_worker, args=("manage2_updates", "update2", "req2", 2001, "tag2")),
    ]

    # Start every worker before joining any of them so they can overlap.
    for worker in workers:
        worker.start()
    for worker in workers:
        worker.join()

    # Check for cross-contamination
    assert results["manage1_updates"] == ("update1",), "Worker 1 updates contaminated"
    assert results["manage2_updates"] == ("update2",), "Worker 2 updates contaminated"
99+
100+
101+
def test_manage_snapshots_concurrent_different_tables() -> None:
    """Create one tag on each of two mock tables from two concurrent threads.

    Every mock transaction records the keyword arguments passed to
    ``_set_ref_snapshot``. After both threads finish, the test checks that both
    workers reported success and that each table recorded exactly one call
    carrying its own snapshot ID and tag name.
    """
    first_table = Mock()
    first_table.metadata = Mock()
    first_table.metadata.table_uuid = uuid4()

    second_table = Mock()
    second_table.metadata = Mock()
    second_table.metadata.table_uuid = uuid4()

    # Keyword arguments seen by each table's mock transaction, in call order.
    first_table_calls: List[Dict[str, Any]] = []
    second_table_calls: List[Dict[str, Any]] = []

    def build_manager(recorded: List[Dict[str, Any]], update: str, requirement: str) -> ManageSnapshots:
        """Return a real ManageSnapshots over a recording mock transaction, with commit stubbed."""
        txn = Mock()

        def record_call(**kwargs: Any) -> Any:
            recorded.append(kwargs)
            return ((update,), (requirement,))

        txn._set_ref_snapshot = Mock(side_effect=record_call)
        manager = ManageSnapshots(txn)
        manager.commit = Mock(return_value=None)
        return manager

    first_table.manage_snapshots = Mock(
        side_effect=lambda: build_manager(first_table_calls, "table1_update", "table1_req")
    )
    second_table.manage_snapshots = Mock(
        side_effect=lambda: build_manager(second_table_calls, "table2_update", "table2_req")
    )

    first_snapshot_id = 1001
    second_snapshot_id = 2001

    def tag_table(table_obj: Any, table_name: str, snapshot_id: int, tag_name: str, outcome: Dict[str, Any]) -> None:
        """Create and commit a tag on ``table_obj``, recording success or the error text."""
        try:
            manager = table_obj.manage_snapshots()
            manager.create_tag(snapshot_id=snapshot_id, tag_name=tag_name).commit()

            outcome["success"] = True
            outcome["snapshot_id"] = snapshot_id
            outcome["tag_name"] = tag_name
        except Exception as exc:
            # Captured rather than raised so the main thread can report it.
            outcome["success"] = False
            outcome["error"] = str(exc)

    outcome1: Dict[str, Any] = {}
    outcome2: Dict[str, Any] = {}

    workers = [
        threading.Thread(target=tag_table, args=(first_table, "table1", first_snapshot_id, "tag1", outcome1)),
        threading.Thread(target=tag_table, args=(second_table, "table2", second_snapshot_id, "tag2", outcome2)),
    ]
    for worker in workers:
        worker.start()
    for worker in workers:
        worker.join()

    # Both workers must have finished without raising.
    assert outcome1.get("success", False), f"Table1 management failed: {outcome1.get('error', 'Unknown error')}"
    assert outcome2.get("success", False), f"Table2 management failed: {outcome2.get('error', 'Unknown error')}"

    # Each table must have recorded exactly its own single operation.
    assert len(first_table_calls) == 1, f"Table1 should have 1 operation, got {len(first_table_calls)}"
    assert len(second_table_calls) == 1, f"Table2 should have 1 operation, got {len(second_table_calls)}"

    # The recorded call must carry the right snapshot ID ...
    assert first_table_calls[0]["snapshot_id"] == first_snapshot_id, "Table1 received wrong snapshot ID"
    assert second_table_calls[0]["snapshot_id"] == second_snapshot_id, "Table2 received wrong snapshot ID"

    # ... and the right tag name.
    assert first_table_calls[0]["ref_name"] == "tag1", "Table1 received wrong tag name"
    assert second_table_calls[0]["ref_name"] == "tag2", "Table2 received wrong tag name"
202+
203+
204+
def test_manage_snapshots_cross_table_isolation() -> None:
    """Test that verifies operations don't get mixed up between different tables.

    Two threads each run a tag operation followed by a branch operation against
    their own mock table. Every mock transaction records the keyword arguments
    passed to ``_set_ref_snapshot``; the test then checks that both workers
    succeeded and that each table saw exactly its own snapshot IDs, in order.
    """
    # Create two mock table objects to simulate real usage
    table1 = Mock()
    table1.metadata = Mock()
    table1.metadata.table_uuid = uuid.uuid4()

    table2 = Mock()
    table2.metadata = Mock()
    table2.metadata.table_uuid = uuid.uuid4()

    # Keyword arguments received by each table's mock transaction, in call order
    table1_manage_calls: List[Dict[str, Any]] = []
    table2_manage_calls: List[Dict[str, Any]] = []

    def make_manager(calls: List[Dict[str, Any]], update: str, requirement: str) -> ManageSnapshots:
        """Return a real ManageSnapshots over a recording mock transaction, with commit stubbed."""
        transaction_mock = Mock()

        def set_ref_side_effect(**kwargs: Any) -> Any:
            calls.append(kwargs)
            return ((update,), (requirement,))

        transaction_mock._set_ref_snapshot = Mock(side_effect=set_ref_side_effect)
        manage_mock = ManageSnapshots(transaction_mock)
        manage_mock.commit = Mock(return_value=None)
        return manage_mock

    table1.manage_snapshots = Mock(side_effect=lambda: make_manager(table1_manage_calls, "update1", "req1"))
    table2.manage_snapshots = Mock(side_effect=lambda: make_manager(table2_manage_calls, "update2", "req2"))

    def manage_from_table(table: Any, table_name: str, operations: List[Dict], results: Dict[str, Any]) -> None:
        """Perform multiple manage operations on a specific table."""
        try:
            for op in operations:
                manager = table.manage_snapshots()
                if op["type"] == "tag":
                    manager.create_tag(snapshot_id=op["snapshot_id"], tag_name=op["name"]).commit()
                elif op["type"] == "branch":
                    manager.create_branch(snapshot_id=op["snapshot_id"], branch_name=op["name"]).commit()

            results["success"] = True
            results["operations"] = operations
        except Exception as e:
            # Captured rather than raised so the main thread can assert on it;
            # an exception escaping a thread would not fail the test by itself.
            results["success"] = False
            results["error"] = str(e)

    # Prepare different operations for each table
    table1_operations = [
        {"type": "tag", "snapshot_id": 1001, "name": "tag1"},
        {"type": "branch", "snapshot_id": 1002, "name": "branch1"},
    ]
    table2_operations = [
        {"type": "tag", "snapshot_id": 2001, "name": "tag2"},
        {"type": "branch", "snapshot_id": 2002, "name": "branch2"},
    ]

    results1: Dict[str, Any] = {}
    results2: Dict[str, Any] = {}

    # Run concurrent management operations
    thread1 = threading.Thread(target=manage_from_table, args=(table1, "table1", table1_operations, results1))
    thread2 = threading.Thread(target=manage_from_table, args=(table2, "table2", table2_operations, results2))

    thread1.start()
    thread2.start()
    thread1.join()
    thread2.join()

    # Surface any exception the workers swallowed. Without these checks a failure
    # raised after the last call was recorded would leave the test passing.
    assert results1.get("success", False), f"Table1 management failed: {results1.get('error', 'Unknown error')}"
    assert results2.get("success", False), f"Table2 management failed: {results2.get('error', 'Unknown error')}"

    # CRITICAL ASSERTION: Each table should only receive its own operations

    # Table1 should only see table1 operations
    table1_snapshot_ids = [call["snapshot_id"] for call in table1_manage_calls]
    expected_table1_ids = [1001, 1002]

    assert table1_snapshot_ids == expected_table1_ids, (
        f"Table1 received unexpected snapshot IDs: {table1_snapshot_ids} (expected {expected_table1_ids})"
    )

    # Table2 should only see table2 operations
    table2_snapshot_ids = [call["snapshot_id"] for call in table2_manage_calls]
    expected_table2_ids = [2001, 2002]

    assert table2_snapshot_ids == expected_table2_ids, (
        f"Table2 received unexpected snapshot IDs: {table2_snapshot_ids} (expected {expected_table2_ids})"
    )

    # Verify no cross-contamination
    table1_received_table2_ids = [sid for sid in table1_snapshot_ids if sid in expected_table2_ids]
    table2_received_table1_ids = [sid for sid in table2_snapshot_ids if sid in expected_table1_ids]

    assert len(table1_received_table2_ids) == 0, f"Table1 incorrectly received Table2 snapshot IDs: {table1_received_table2_ids}"

    assert len(table2_received_table1_ids) == 0, f"Table2 incorrectly received Table1 snapshot IDs: {table2_received_table1_ids}"
315+
316+
317+
def test_manage_snapshots_concurrent_same_table_different_operations() -> None:
    """Run a tag operation and a branch operation on two concurrent threads.

    Each operation gets a fresh mock transaction and a fresh ManageSnapshots
    instance. The pending updates captured from each thread are then compared
    with the single update the mock transaction returns.
    """
    # Snapshot ID shared by both operations
    current_snapshot_id = 12345

    def manage_snapshots_thread_func(operations: List[Dict], results: Dict[str, Any]) -> None:
        """Apply ``operations`` one by one and store the last manager's pending state."""
        try:
            for op in operations:
                # Every operation works against its own transaction and manager.
                txn = Mock()
                txn._set_ref_snapshot = Mock(return_value=(("update",), ("req",)))
                manager = ManageSnapshots(txn)

                op_type = op["type"]
                if op_type == "tag":
                    manager.create_tag(snapshot_id=op["snapshot_id"], tag_name=op["name"])
                elif op_type == "branch":
                    manager.create_branch(snapshot_id=op["snapshot_id"], branch_name=op["name"])

                # Capture this manager's state for inspection by the main thread.
                results["updates"] = manager._updates
                results["requirements"] = manager._requirements
            results["success"] = True
        except Exception as exc:
            results["success"] = False
            results["error"] = str(exc)

    # One distinct operation per thread
    tag_operations = [{"type": "tag", "snapshot_id": current_snapshot_id, "name": "test_tag_1"}]
    branch_operations = [{"type": "branch", "snapshot_id": current_snapshot_id, "name": "test_branch_2"}]

    results1: Dict[str, Any] = {}
    results2: Dict[str, Any] = {}

    workers = [
        threading.Thread(target=manage_snapshots_thread_func, args=(tag_operations, results1)),
        threading.Thread(target=manage_snapshots_thread_func, args=(branch_operations, results2)),
    ]
    for worker in workers:
        worker.start()
    for worker in workers:
        worker.join()

    # Both threads must have completed without raising.
    assert results1.get("success", False), f"Thread 1 management failed: {results1.get('error', 'Unknown error')}"
    assert results2.get("success", False), f"Thread 2 management failed: {results2.get('error', 'Unknown error')}"

    # Each thread must hold exactly the one update its own transaction returned.
    assert results1["updates"] == ("update",), f"Thread 1 should have ('update',), got {results1['updates']}"
    assert results2["updates"] == ("update",), f"Thread 2 should have ('update',), got {results2['updates']}"

0 commit comments

Comments
 (0)