Skip to content

Commit a3292f6

Browse files
committed
Build: Enhance type annotations in ManageSnapshots tests for better clarity and type safety
1 parent 1af9a94 commit a3292f6

File tree

1 file changed

+36
-42
lines changed

1 file changed

+36
-42
lines changed

tests/table/test_manage_snapshots_thread_safety.py

Lines changed: 36 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ def test_manage_snapshots_thread_safety_fix() -> None:
6464

6565
def test_manage_snapshots_concurrent_operations() -> None:
6666
"""Test concurrent operations with separate ManageSnapshots instances."""
67-
results: Dict[str, tuple] = {"manage1_updates": (), "manage2_updates": ()}
67+
results: Dict[str, tuple[Any, ...]] = {"manage1_updates": (), "manage2_updates": ()}
6868

6969
def worker1() -> None:
7070
transaction1 = Mock()
@@ -114,31 +114,31 @@ def test_manage_snapshots_concurrent_different_tables() -> None:
114114
table2.metadata.table_uuid = uuid4()
115115

116116
# Track calls to each table's manage operations
117-
table1_operations = []
118-
table2_operations = []
117+
table1_operations: List[Dict[str, Any]] = []
118+
table2_operations: List[Dict[str, Any]] = []
119119

120-
def create_table1_manage_mock():
120+
def create_table1_manage_mock() -> ManageSnapshots:
121121
transaction_mock = Mock()
122-
123-
def set_ref_snapshot_side_effect(**kwargs):
122+
123+
def set_ref_snapshot_side_effect(**kwargs: Any) -> tuple[tuple[str, ...], tuple[str, ...]]:
124124
table1_operations.append(kwargs)
125125
return (("table1_update",), ("table1_req",))
126-
126+
127127
transaction_mock._set_ref_snapshot = Mock(side_effect=set_ref_snapshot_side_effect)
128128
manage_mock = ManageSnapshots(transaction_mock)
129-
manage_mock.commit = Mock(return_value=None)
129+
manage_mock.commit = Mock(return_value=None) # type: ignore[method-assign]
130130
return manage_mock
131131

132-
def create_table2_manage_mock():
132+
def create_table2_manage_mock() -> ManageSnapshots:
133133
transaction_mock = Mock()
134-
135-
def set_ref_snapshot_side_effect(**kwargs):
134+
135+
def set_ref_snapshot_side_effect(**kwargs: Any) -> tuple[tuple[str, ...], tuple[str, ...]]:
136136
table2_operations.append(kwargs)
137137
return (("table2_update",), ("table2_req",))
138-
138+
139139
transaction_mock._set_ref_snapshot = Mock(side_effect=set_ref_snapshot_side_effect)
140140
manage_mock = ManageSnapshots(transaction_mock)
141-
manage_mock.commit = Mock(return_value=None)
141+
manage_mock.commit = Mock(return_value=None) # type: ignore[method-assign]
142142
return manage_mock
143143

144144
table1.manage_snapshots = Mock(side_effect=create_table1_manage_mock)
@@ -148,9 +148,7 @@ def set_ref_snapshot_side_effect(**kwargs):
148148
table1_snapshot_id = 1001
149149
table2_snapshot_id = 2001
150150

151-
def manage_table_snapshots(
152-
table_obj: Any, table_name: str, snapshot_id: int, tag_name: str, results: Dict[str, Any]
153-
) -> None:
151+
def manage_table_snapshots(table_obj: Any, table_name: str, snapshot_id: int, tag_name: str, results: Dict[str, Any]) -> None:
154152
"""Manage snapshots for a specific table."""
155153
try:
156154
# Create tag operation (as in real usage)
@@ -168,12 +166,8 @@ def manage_table_snapshots(
168166
results2: Dict[str, Any] = {}
169167

170168
# Create threads to manage snapshots for different tables concurrently
171-
thread1 = threading.Thread(
172-
target=manage_table_snapshots, args=(table1, "table1", table1_snapshot_id, "tag1", results1)
173-
)
174-
thread2 = threading.Thread(
175-
target=manage_table_snapshots, args=(table2, "table2", table2_snapshot_id, "tag2", results2)
176-
)
169+
thread1 = threading.Thread(target=manage_table_snapshots, args=(table1, "table1", table1_snapshot_id, "tag1", results1))
170+
thread2 = threading.Thread(target=manage_table_snapshots, args=(table2, "table2", table2_snapshot_id, "tag2", results2))
177171

178172
# Start threads concurrently
179173
thread1.start()
@@ -218,37 +212,37 @@ def test_manage_snapshots_cross_table_isolation() -> None:
218212
table2.metadata.table_uuid = uuid.uuid4()
219213

220214
# Track which operations each table's manage operation receives
221-
table1_manage_calls = []
222-
table2_manage_calls = []
215+
table1_manage_calls: List[Dict[str, Any]] = []
216+
table2_manage_calls: List[Dict[str, Any]] = []
223217

224-
def mock_table1_manage():
218+
def mock_table1_manage() -> ManageSnapshots:
225219
transaction_mock = Mock()
226-
227-
def set_ref_side_effect(**kwargs):
220+
221+
def set_ref_side_effect(**kwargs: Any) -> tuple[tuple[str, ...], tuple[str, ...]]:
228222
table1_manage_calls.append(kwargs)
229223
return (("update1",), ("req1",))
230-
224+
231225
transaction_mock._set_ref_snapshot = Mock(side_effect=set_ref_side_effect)
232226
manage_mock = ManageSnapshots(transaction_mock)
233-
manage_mock.commit = Mock(return_value=None)
227+
manage_mock.commit = Mock(return_value=None) # type: ignore[method-assign]
234228
return manage_mock
235229

236-
def mock_table2_manage():
230+
def mock_table2_manage() -> ManageSnapshots:
237231
transaction_mock = Mock()
238-
239-
def set_ref_side_effect(**kwargs):
232+
233+
def set_ref_side_effect(**kwargs: Any) -> tuple[tuple[str, ...], tuple[str, ...]]:
240234
table2_manage_calls.append(kwargs)
241235
return (("update2",), ("req2",))
242-
236+
243237
transaction_mock._set_ref_snapshot = Mock(side_effect=set_ref_side_effect)
244238
manage_mock = ManageSnapshots(transaction_mock)
245-
manage_mock.commit = Mock(return_value=None)
239+
manage_mock.commit = Mock(return_value=None) # type: ignore[method-assign]
246240
return manage_mock
247241

248242
table1.manage_snapshots = Mock(side_effect=mock_table1_manage)
249243
table2.manage_snapshots = Mock(side_effect=mock_table2_manage)
250244

251-
def manage_from_table(table: Any, table_name: str, operations: List[Dict], results: Dict[str, Any]) -> None:
245+
def manage_from_table(table: Any, table_name: str, operations: List[Dict[str, Any]], results: Dict[str, Any]) -> None:
252246
"""Perform multiple manage operations on a specific table."""
253247
try:
254248
for op in operations:
@@ -257,7 +251,7 @@ def manage_from_table(table: Any, table_name: str, operations: List[Dict], resul
257251
manager.create_tag(snapshot_id=op["snapshot_id"], tag_name=op["name"]).commit()
258252
elif op["type"] == "branch":
259253
manager.create_branch(snapshot_id=op["snapshot_id"], branch_name=op["name"]).commit()
260-
254+
261255
results["success"] = True
262256
results["operations"] = operations
263257
except Exception as e:
@@ -292,15 +286,15 @@ def manage_from_table(table: Any, table_name: str, operations: List[Dict], resul
292286
# Table1 should only see table1 operations
293287
table1_snapshot_ids = [call["snapshot_id"] for call in table1_manage_calls]
294288
expected_table1_ids = [1001, 1002]
295-
289+
296290
assert table1_snapshot_ids == expected_table1_ids, (
297291
f"Table1 received unexpected snapshot IDs: {table1_snapshot_ids} (expected {expected_table1_ids})"
298292
)
299293

300294
# Table2 should only see table2 operations
301295
table2_snapshot_ids = [call["snapshot_id"] for call in table2_manage_calls]
302296
expected_table2_ids = [2001, 2002]
303-
297+
304298
assert table2_snapshot_ids == expected_table2_ids, (
305299
f"Table2 received unexpected snapshot IDs: {table2_snapshot_ids} (expected {expected_table2_ids})"
306300
)
@@ -316,17 +310,17 @@ def manage_from_table(table: Any, table_name: str, operations: List[Dict], resul
316310

317311
def test_manage_snapshots_concurrent_same_table_different_operations() -> None:
318312
"""Test that concurrent ManageSnapshots operations work correctly."""
319-
313+
320314
# Mock current snapshot ID
321315
current_snapshot_id = 12345
322316

323317
# Create mock transactions that return the expected format
324-
def create_mock_transaction():
318+
def create_mock_transaction() -> Mock:
325319
transaction_mock = Mock()
326320
transaction_mock._set_ref_snapshot = Mock(return_value=(("update",), ("req",)))
327321
return transaction_mock
328322

329-
def manage_snapshots_thread_func(operations: List[Dict], results: Dict[str, Any]) -> None:
323+
def manage_snapshots_thread_func(operations: List[Dict[str, Any]], results: Dict[str, Any]) -> None:
330324
"""Function to run in a thread that performs manage snapshot operations and captures results."""
331325
try:
332326
for op in operations:
@@ -365,7 +359,7 @@ def manage_snapshots_thread_func(operations: List[Dict], results: Dict[str, Any]
365359
# Assert that both operations succeeded
366360
assert results1.get("success", False), f"Thread 1 management failed: {results1.get('error', 'Unknown error')}"
367361
assert results2.get("success", False), f"Thread 2 management failed: {results2.get('error', 'Unknown error')}"
368-
362+
369363
# Verify that each thread has its own isolated state
370364
assert results1["updates"] == ("update",), f"Thread 1 should have ('update',), got {results1['updates']}"
371365
assert results2["updates"] == ("update",), f"Thread 2 should have ('update',), got {results2['updates']}"

0 commit comments

Comments
 (0)