1414
1515"""Context manager for Cloud Spanner batched writes."""
1616import functools
17+ from typing import List , Optional
1718
18- from google .cloud .spanner_v1 import CommitRequest
19+ from google .cloud .spanner_v1 import CommitRequest , CommitResponse
1920from google .cloud .spanner_v1 import Mutation
2021from google .cloud .spanner_v1 import TransactionOptions
2122from google .cloud .spanner_v1 import BatchWriteRequest
@@ -47,22 +48,15 @@ class _BatchBase(_SessionWrapper):
4748 :param session: the session used to perform the commit
4849 """
4950
50- transaction_tag = None
51- _read_only = False
52-
5351 def __init__ (self , session ):
5452 super (_BatchBase , self ).__init__ (session )
55- self ._mutations = []
56-
57- def _check_state (self ):
58- """Helper for :meth:`commit` et al.
5953
60- Subclasses must override
54+ self ._mutations : List [Mutation ] = []
55+ self .transaction_tag : Optional [str ] = None
6156
62- :raises: :exc:`ValueError` if the object's state is invalid for making
63- API requests.
64- """
65- raise NotImplementedError
57+ self .committed = None
58+ """Timestamp at which the batch was successfully committed."""
59+ self .commit_stats : Optional [CommitResponse .CommitStats ] = None
6660
6761 def insert (self , table , columns , values ):
6862 """Insert one or more new table rows.
@@ -148,29 +142,15 @@ def delete(self, table, keyset):
148142class Batch (_BatchBase ):
149143 """Accumulate mutations for transmission during :meth:`commit`."""
150144
151- committed = None
152- commit_stats = None
153- """Timestamp at which the batch was successfully committed."""
154-
155- def _check_state (self ):
156- """Helper for :meth:`commit` et al.
157-
158- Subclasses must override
159-
160- :raises: :exc:`ValueError` if the object's state is invalid for making
161- API requests.
162- """
163- if self .committed is not None :
164- raise ValueError ("Batch already committed" )
165-
166145 def commit (
167146 self ,
168147 return_commit_stats = False ,
169148 request_options = None ,
170149 max_commit_delay = None ,
171150 exclude_txn_from_change_streams = False ,
172151 isolation_level = TransactionOptions .IsolationLevel .ISOLATION_LEVEL_UNSPECIFIED ,
173- ** kwargs ,
152+ timeout_secs = DEFAULT_RETRY_TIMEOUT_SECS ,
153+ default_retry_delay = None ,
174154 ):
175155 """Commit mutations to the database.
176156
@@ -202,12 +182,26 @@ def commit(
202182 :param isolation_level:
203183 (Optional) Sets isolation level for the transaction.
204184
185+ :type timeout_secs: int
186+ :param timeout_secs: (Optional) The maximum time in seconds to wait for the commit to complete.
187+
188+ :type default_retry_delay: int
189+ :param timeout_secs: (Optional) The default time in seconds to wait before re-trying the commit..
190+
205191 :rtype: datetime
206192 :returns: timestamp of the committed changes.
193+
194+ :raises: ValueError: if the transaction is not ready to commit.
207195 """
208- self ._check_state ()
209- database = self ._session ._database
196+
197+ if self .committed is not None :
198+ raise ValueError ("Transaction already committed." )
199+
200+ mutations = self ._mutations
201+ session = self ._session
202+ database = session ._database
210203 api = database .spanner_api
204+
211205 metadata = _metadata_with_prefix (database .name )
212206 if database ._route_to_leader_enabled :
213207 metadata .append (
@@ -223,7 +217,6 @@ def commit(
223217 database .default_transaction_options .default_read_write_transaction_options ,
224218 txn_options ,
225219 )
226- trace_attributes = {"num_mutations" : len (self ._mutations )}
227220
228221 if request_options is None :
229222 request_options = RequestOptions ()
@@ -234,27 +227,26 @@ def commit(
234227 # Request tags are not supported for commit requests.
235228 request_options .request_tag = None
236229
237- request = CommitRequest (
238- session = self ._session .name ,
239- mutations = self ._mutations ,
240- single_use_transaction = txn_options ,
241- return_commit_stats = return_commit_stats ,
242- max_commit_delay = max_commit_delay ,
243- request_options = request_options ,
244- )
245- observability_options = getattr (database , "observability_options" , None )
246230 with trace_call (
247- f"CloudSpanner.{ type (self ).__name__ } .commit" ,
248- self . _session ,
249- trace_attributes ,
250- observability_options = observability_options ,
231+ name = f"CloudSpanner.{ type (self ).__name__ } .commit" ,
232+ session = session ,
233+ extra_attributes = { "num_mutations" : len ( mutations )} ,
234+ observability_options = getattr ( database , " observability_options" , None ) ,
251235 metadata = metadata ,
252236 ) as span , MetricsCapture ():
253237
254- def wrapped_method (* args , ** kwargs ):
255- method = functools .partial (
238+ def wrapped_method ():
239+ commit_request = CommitRequest (
240+ session = session .name ,
241+ mutations = mutations ,
242+ single_use_transaction = txn_options ,
243+ return_commit_stats = return_commit_stats ,
244+ max_commit_delay = max_commit_delay ,
245+ request_options = request_options ,
246+ )
247+ commit_method = functools .partial (
256248 api .commit ,
257- request = request ,
249+ request = commit_request ,
258250 metadata = database .metadata_with_request_id (
259251 # This code is retried due to ABORTED, hence nth_request
260252 # should be increased. attempt can only be increased if
@@ -265,24 +257,23 @@ def wrapped_method(*args, **kwargs):
265257 span ,
266258 ),
267259 )
268- return method ( * args , ** kwargs )
260+ return commit_method ( )
269261
270- deadline = time .time () + kwargs .get (
271- "timeout_secs" , DEFAULT_RETRY_TIMEOUT_SECS
272- )
273- default_retry_delay = kwargs .get ("default_retry_delay" , None )
274262 response = _retry_on_aborted_exception (
275263 wrapped_method ,
276- deadline = deadline ,
264+ deadline = time . time () + timeout_secs ,
277265 default_retry_delay = default_retry_delay ,
278266 )
267+
279268 self .committed = response .commit_timestamp
280269 self .commit_stats = response .commit_stats
270+
281271 return self .committed
282272
283273 def __enter__ (self ):
284274 """Begin ``with`` block."""
285- self ._check_state ()
275+ if self .committed is not None :
276+ raise ValueError ("Transaction already committed" )
286277
287278 return self
288279
@@ -317,20 +308,10 @@ class MutationGroups(_SessionWrapper):
317308 :param session: the session used to perform the commit
318309 """
319310
320- committed = None
321-
322311 def __init__ (self , session ):
323312 super (MutationGroups , self ).__init__ (session )
324- self ._mutation_groups = []
325-
326- def _check_state (self ):
327- """Checks if the object's state is valid for making API requests.
328-
329- :raises: :exc:`ValueError` if the object's state is invalid for making
330- API requests.
331- """
332- if self .committed is not None :
333- raise ValueError ("MutationGroups already committed" )
313+ self ._mutation_groups : List [MutationGroup ] = []
314+ self .committed : bool = False
334315
335316 def group (self ):
336317 """Returns a new `MutationGroup` to which mutations can be added."""
@@ -358,57 +339,62 @@ def batch_write(self, request_options=None, exclude_txn_from_change_streams=Fals
358339 :rtype: :class:`Iterable[google.cloud.spanner_v1.types.BatchWriteResponse]`
359340 :returns: a sequence of responses for each batch.
360341 """
361- self ._check_state ()
362342
363- database = self ._session ._database
343+ if self .committed :
344+ raise ValueError ("MutationGroups already committed" )
345+
346+ mutation_groups = self ._mutation_groups
347+ session = self ._session
348+ database = session ._database
364349 api = database .spanner_api
350+
365351 metadata = _metadata_with_prefix (database .name )
366352 if database ._route_to_leader_enabled :
367353 metadata .append (
368354 _metadata_with_leader_aware_routing (database ._route_to_leader_enabled )
369355 )
370- trace_attributes = { "num_mutation_groups" : len ( self . _mutation_groups )}
356+
371357 if request_options is None :
372358 request_options = RequestOptions ()
373359 elif type (request_options ) is dict :
374360 request_options = RequestOptions (request_options )
375361
376- request = BatchWriteRequest (
377- session = self ._session .name ,
378- mutation_groups = self ._mutation_groups ,
379- request_options = request_options ,
380- exclude_txn_from_change_streams = exclude_txn_from_change_streams ,
381- )
382- observability_options = getattr (database , "observability_options" , None )
383362 with trace_call (
384- "CloudSpanner.batch_write" ,
385- self . _session ,
386- trace_attributes ,
387- observability_options = observability_options ,
363+ name = "CloudSpanner.batch_write" ,
364+ session = session ,
365+ extra_attributes = { "num_mutation_groups" : len ( mutation_groups )} ,
366+ observability_options = getattr ( database , " observability_options" , None ) ,
388367 metadata = metadata ,
389368 ) as span , MetricsCapture ():
390369 attempt = AtomicCounter (0 )
391370 nth_request = getattr (database , "_next_nth_request" , 0 )
392371
393- def wrapped_method (* args , ** kwargs ):
394- method = functools .partial (
372+ def wrapped_method ():
373+ batch_write_request = BatchWriteRequest (
374+ session = session .name ,
375+ mutation_groups = mutation_groups ,
376+ request_options = request_options ,
377+ exclude_txn_from_change_streams = exclude_txn_from_change_streams ,
378+ )
379+ batch_write_method = functools .partial (
395380 api .batch_write ,
396- request = request ,
381+ request = batch_write_request ,
397382 metadata = database .metadata_with_request_id (
398383 nth_request ,
399384 attempt .increment (),
400385 metadata ,
401386 span ,
402387 ),
403388 )
404- return method ( * args , ** kwargs )
389+ return batch_write_method ( )
405390
406391 response = _retry (
407392 wrapped_method ,
408393 allowed_exceptions = {
409394 InternalServerError : _check_rst_stream_error ,
410395 },
411396 )
397+
412398 self .committed = True
413399 return response
414400
0 commit comments