This repository was archived by the owner on Mar 31, 2026. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 104
Expand file tree
/
Copy pathtransaction_helper.py
More file actions
294 lines (257 loc) · 11.6 KB
/
transaction_helper.py
File metadata and controls
294 lines (257 loc) · 11.6 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
# Copyright 2023 Google LLC All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from enum import Enum
from typing import TYPE_CHECKING, List, Any, Dict
from google.api_core.exceptions import Aborted
import time
from google.cloud.spanner_dbapi.batch_dml_executor import BatchMode
from google.cloud.spanner_dbapi.exceptions import RetryAborted
from google.cloud.spanner_v1._helpers import _get_retry_delay
if TYPE_CHECKING:
from google.cloud.spanner_dbapi import Connection, Cursor
from google.cloud.spanner_dbapi.checksum import ResultsChecksum, _compare_checksums
# Upper bound on the number of internal replay attempts made by
# TransactionRetryHelper.retry_transaction for one aborted transaction.
MAX_INTERNAL_RETRIES = 50
# Message passed to RetryAborted when a replayed statement does not reproduce
# the outcome recorded for it in the original transaction.
RETRY_ABORTED_ERROR = "The transaction was aborted and could not be retried due to a concurrent modification."
class TransactionRetryHelper:
    """Record the statements of a transaction so they can be replayed.

    While a client-side transaction is active, every execute/fetch call made
    on a cursor is stored in order, together with a summary of its outcome
    (row count, results checksum or raised exception).
    :meth:`retry_transaction` replays the recorded statements on new cursors
    and raises ``RetryAborted`` when a replayed outcome differs from the
    recorded one.
    """

    def __init__(self, connection: "Connection"):
        """Helper class used in retrying the transaction when aborted. This
        will maintain all the statements executed on original transaction and
        replay them again in the retried transaction.

        :type connection: :class:`~google.cloud.spanner_dbapi.connection.Connection`
        :param connection: A DB-API connection to Google Cloud Spanner.
        """
        self._connection = connection
        # list of all statements in the same order as executed in original
        # transaction along with their results
        self._statement_result_details_list: List[StatementDetails] = []
        # Map of last StatementDetails that was added to a particular cursor
        self._last_statement_details_per_cursor: Dict[Cursor, StatementDetails] = {}
        # 1-1 map from original cursor object on which transaction ran to the
        # new cursor object used in the retry
        self._cursor_map: Dict[Cursor, Cursor] = {}

    def _set_connection_for_retry(self):
        """Clear the connection's transaction flags and batch mode.

        Called at the start of every replay attempt.
        """
        self._connection._spanner_transaction_started = False
        self._connection._transaction_begin_marked = False
        self._connection._batch_mode = BatchMode.NONE

    def reset(self):
        """
        Resets the state of the class when the ongoing transaction is committed
        or aborted
        """
        self._statement_result_details_list = []
        self._last_statement_details_per_cursor = {}
        self._cursor_map = {}

    def add_fetch_statement_for_retry(
        self, cursor, result_rows, exception, is_fetch_all
    ):
        """
        StatementDetails to be added to _statement_result_details_list whenever fetchone, fetchmany or
        fetchall method is called on the cursor.
        If fetchone is consecutively called n times then it is stored as fetchmany with size as n.
        Same for fetchmany, so consecutive fetchone and fetchmany statements are stored as one
        fetchmany statement in _statement_result_details_list with size param appropriately set

        Nothing is recorded when the connection has no client-side
        transaction started.

        :param cursor: original Cursor object on which statement executed in the transaction
        :param result_rows: All the rows from the resultSet from fetch statement execution
        :param exception: Not none in case non-aborted exception is thrown on the original
        statement execution
        :param is_fetch_all: True in case of fetchall statement execution
        """
        if not self._connection._client_transaction_started:
            return

        last_statement_result_details = self._last_statement_details_per_cursor.get(
            cursor
        )
        if (
            last_statement_result_details is not None
            and last_statement_result_details.statement_type
            == CursorStatementType.FETCH_MANY
        ):
            # The previous call on this cursor was also a fetchone/fetchmany:
            # fold this call into the existing record instead of adding one.
            if exception is not None:
                last_statement_result_details.result_type = ResultType.EXCEPTION
                last_statement_result_details.result_details = exception
            else:
                for row in result_rows:
                    last_statement_result_details.result_details.consume_result(row)
                last_statement_result_details.size += len(result_rows)
        else:
            # NOTE(review): ``exception`` is not consulted on this path; a new
            # record always stores the checksum of ``result_rows``.
            result_details = _get_statement_result_checksum(result_rows)
            if is_fetch_all:
                statement_type = CursorStatementType.FETCH_ALL
                size = None
            else:
                statement_type = CursorStatementType.FETCH_MANY
                size = len(result_rows)

            last_statement_result_details = FetchStatement(
                cursor=cursor,
                statement_type=statement_type,
                result_type=ResultType.CHECKSUM,
                result_details=result_details,
                size=size,
            )
            self._last_statement_details_per_cursor[
                cursor
            ] = last_statement_result_details
            self._statement_result_details_list.append(last_statement_result_details)

    def add_execute_statement_for_retry(
        self, cursor, sql, args, exception, is_execute_many
    ):
        """
        StatementDetails to be added to _statement_result_details_list whenever execute or
        executemany method is called on the cursor.

        Nothing is recorded when the connection has no client-side
        transaction started.

        :param cursor: original Cursor object on which statement executed in the transaction
        :param sql: Input param of the execute/executemany method
        :param args: Input param of the execute/executemany method
        :param exception: Not none in case non-aborted exception is thrown on the original
        statement execution
        :param is_execute_many: True in case of executemany statement execution
        """
        if not self._connection._client_transaction_started:
            return

        statement_type = CursorStatementType.EXECUTE
        if is_execute_many:
            statement_type = CursorStatementType.EXECUTE_MANY

        # Outcome precedence: raised exception, then batch DML row counts,
        # then the plain row count; otherwise nothing is recorded.
        result_type = ResultType.NONE
        result_details = None
        if exception is not None:
            result_type = ResultType.EXCEPTION
            result_details = exception
        elif cursor._batch_dml_rows_count is not None:
            result_type = ResultType.BATCH_DML_ROWS_COUNT
            result_details = cursor._batch_dml_rows_count
        elif cursor._row_count is not None:
            result_type = ResultType.ROW_COUNT
            result_details = cursor.rowcount

        last_statement_result_details = ExecuteStatement(
            cursor=cursor,
            statement_type=statement_type,
            sql=sql,
            args=args,
            result_type=result_type,
            result_details=result_details,
        )
        self._last_statement_details_per_cursor[cursor] = last_statement_result_details
        self._statement_result_details_list.append(last_statement_result_details)

    def retry_transaction(self, default_retry_delay=None):
        """Retry the aborted transaction.

        All the statements executed in the original transaction
        will be re-executed in new one. Results checksums of the
        original statements and the retried ones will be compared.

        :param default_retry_delay: forwarded to ``_get_retry_delay`` when
            computing how long to wait after a replay attempt was aborted.

        :raises: :class:`google.cloud.spanner_dbapi.exceptions.RetryAborted`
            If results checksum of the retried statement is
            not equal to the checksum of the original one.
        :raises: :class:`google.api_core.exceptions.Aborted`
            If the replay itself is aborted ``MAX_INTERNAL_RETRIES`` times.
        """
        attempt = 0
        while True:
            attempt += 1
            self._set_connection_for_retry()
            try:
                for statement_result_details in self._statement_result_details_list:
                    self._replay_statement(statement_result_details)
                return
            except Aborted as ex:
                if attempt >= MAX_INTERNAL_RETRIES:
                    # Retry budget exhausted. Re-raise from inside the handler
                    # so the latest Aborted error always propagates, whether
                    # or not the caller is itself handling an exception.
                    raise
                delay = _get_retry_delay(
                    ex.errors[0], attempt, default_retry_delay=default_retry_delay
                )
                if delay:
                    time.sleep(delay)

    def _replay_statement(self, statement_result_details):
        """Re-execute one recorded statement on the retry cursor mapped to it.

        :param statement_result_details: the recorded statement to replay
        :raises: :class:`google.cloud.spanner_dbapi.exceptions.RetryAborted`
            If the replay raises an exception whose type or args differ from
            the outcome recorded for the original statement.
        """
        original_cursor = statement_result_details.cursor
        cursor = self._cursor_map.get(original_cursor)
        if cursor is None:
            # First time this original cursor is seen: give it a dedicated
            # retry cursor, flagged as being in retry mode.
            cursor = self._connection.cursor()
            cursor._in_retry_mode = True
            self._cursor_map[original_cursor] = cursor

        try:
            _handle_statement(statement_result_details, cursor)
        except (Aborted, RetryAborted):
            # Aborted restarts the replay in retry_transaction; RetryAborted
            # is final. Neither may be swallowed here.
            raise
        except Exception as ex:
            # Any other error is acceptable only if the original statement
            # failed with an exception of the same type and args.
            recorded = statement_result_details.result_details
            if type(recorded) is not type(ex) or ex.args != recorded.args:
                raise RetryAborted(RETRY_ABORTED_ERROR, ex)
def _handle_statement(statement_result_details, cursor):
    """Replay one recorded statement on ``cursor`` and verify its outcome.

    :param statement_result_details: recorded statement (execute or fetch)
        together with the outcome observed in the original transaction
    :param cursor: cursor on which the statement is re-executed
    :raises RetryAborted: if the replayed row count or batch DML row counts
        differ from the recorded ones, or if the original statement recorded
        an exception but the replay reached the end without raising one.
        A mismatch of fetch results is reported by ``_compare_checksums``.
    """
    statement_type = statement_result_details.statement_type
    if _is_execute_type_statement(statement_type):
        if statement_type == CursorStatementType.EXECUTE:
            cursor.execute(statement_result_details.sql, statement_result_details.args)
            # A recorded DML row count must be reproduced exactly.
            if (
                statement_result_details.result_type == ResultType.ROW_COUNT
                and statement_result_details.result_details != cursor.rowcount
            ):
                raise RetryAborted(RETRY_ABORTED_ERROR)
        else:
            cursor.executemany(
                statement_result_details.sql, statement_result_details.args
            )
        # Batch DML row counts can be recorded for either execute kind, so
        # this check applies after both branches above.
        if (
            statement_result_details.result_type == ResultType.BATCH_DML_ROWS_COUNT
            and statement_result_details.result_details != cursor._batch_dml_rows_count
        ):
            raise RetryAborted(RETRY_ABORTED_ERROR)
    else:
        # Fetch statements: re-fetch the same amount of rows and compare the
        # checksum of what came back with the recorded outcome.
        if statement_type == CursorStatementType.FETCH_ALL:
            res = cursor.fetchall()
        else:
            res = cursor.fetchmany(statement_result_details.size)
        checksum = _get_statement_result_checksum(res)
        _compare_checksums(checksum, statement_result_details.result_details)
    # Reaching this point means the replay did not raise; if the original
    # statement did, the two transactions diverged.
    if statement_result_details.result_type == ResultType.EXCEPTION:
        raise RetryAborted(RETRY_ABORTED_ERROR)
def _is_execute_type_statement(statement_type):
    """Return True when ``statement_type`` is an execute or executemany call."""
    execute_kinds = (
        CursorStatementType.EXECUTE,
        CursorStatementType.EXECUTE_MANY,
    )
    return statement_type in execute_kinds
def _get_statement_result_checksum(res_iter):
    """Fold every result in ``res_iter`` into a new ``ResultsChecksum``.

    :param res_iter: iterable of results, consumed in iteration order
    :return: the ``ResultsChecksum`` after all results have been consumed
    """
    checksum = ResultsChecksum()
    for row in res_iter:
        checksum.consume_result(row)
    return checksum
class CursorStatementType(Enum):
    """Kind of cursor call recorded for replay in a retried transaction."""

    # Cursor.execute call
    EXECUTE = 1
    # Cursor.executemany call
    EXECUTE_MANY = 2
    # Not assigned by the recording code in this file: fetchone calls are
    # stored as FETCH_MANY records.
    FETCH_ONE = 3
    # Cursor.fetchall call
    FETCH_ALL = 4
    # One or more consecutive fetchone/fetchmany calls on the same cursor
    FETCH_MANY = 5
class ResultType(Enum):
    """Kind of outcome stored in ``StatementDetails.result_details``."""

    # checksum of ResultSet in case of fetch call on query statement
    CHECKSUM = 1
    # None in case of execute call on query statement
    NONE = 2
    # Exception details in case of any statement execution throws exception
    EXCEPTION = 3
    # Total rows updated in case of execute call on DML statement
    ROW_COUNT = 4
    # Total rows updated in case of Batch DML statement execution
    BATCH_DML_ROWS_COUNT = 5
@dataclass
class StatementDetails:
    """Base record of one cursor call and the outcome observed for it."""

    # Kind of cursor call that was recorded
    statement_type: CursorStatementType
    # The cursor object on which this statement was executed
    cursor: "Cursor"
    # Tells which kind of outcome ``result_details`` holds
    result_type: ResultType
    # The recorded outcome: a results checksum, an exception, a row count,
    # batch DML row counts or None, depending on ``result_type``
    result_details: Any
@dataclass
class ExecuteStatement(StatementDetails):
    """Record of an execute/executemany call, including its inputs."""

    # ``sql`` argument given to execute/executemany
    sql: str
    # ``args`` argument given to execute/executemany
    args: Any = None
@dataclass
class FetchStatement(StatementDetails):
    """Record of a fetch call (fetchall, or merged fetchone/fetchmany calls)."""

    # Number of rows to request from fetchmany on replay; None for fetchall
    size: int = None