1515"""Helpers for applying Google Cloud Firestore changes in a transaction."""
1616from __future__ import annotations
1717
18- from typing import TYPE_CHECKING , Any , AsyncGenerator , Callable , Coroutine , Optional
18+ from typing import TYPE_CHECKING , Any , AsyncGenerator , Awaitable , Callable , Coroutine , Optional , TypeVar , ParamSpec , Concatenate
1919
2020from google .api_core import exceptions , gapic_v1
2121from google .api_core import retry_async as retries
4141 from google .cloud .firestore_v1 .query_profile import ExplainOptions
4242
4343
44+ T = TypeVar ("T" )
45+ P = ParamSpec ("P" )
46+
4447class AsyncTransaction (async_batch .AsyncWriteBatch , BaseTransaction ):
4548 """Accumulate read-and-write operations to be sent in a transaction.
4649
@@ -236,11 +239,11 @@ class _AsyncTransactional(_BaseTransactional):
236239 A coroutine that should be run (and retried) in a transaction.
237240 """
238241
239- def __init__ (self , to_wrap ) -> None :
242+ def __init__ (self , to_wrap : Callable [ Concatenate [ AsyncTransaction , P ], Awaitable [ T ]] ) -> None :
240243 super (_AsyncTransactional , self ).__init__ (to_wrap )
241244
242245 async def _pre_commit (
243- self , transaction : AsyncTransaction , * args , ** kwargs
246+ self , transaction : AsyncTransaction , * args : P . args , ** kwargs : P . kwargs
244247 ) -> Coroutine :
245248 """Begin transaction and call the wrapped coroutine.
246249
@@ -254,7 +257,7 @@ async def _pre_commit(
254257 along to the wrapped coroutine.
255258
256259 Returns:
257- Any : result of the wrapped coroutine.
260+ T : result of the wrapped coroutine.
258261
259262 Raises:
260263 Exception: Any failure caused by ``to_wrap``.
@@ -269,20 +272,20 @@ async def _pre_commit(
269272 self .retry_id = self .current_id
270273 return await self .to_wrap (transaction , * args , ** kwargs )
271274
272- async def __call__ (self , transaction , * args , ** kwargs ) :
275+ async def __call__ (self , transaction : AsyncTransaction , * args : P . args , ** kwargs : P . kwargs ) -> T :
273276 """Execute the wrapped callable within a transaction.
274277
275278 Args:
276279 transaction
277- (:class:`~google.cloud.firestore_v1.transaction.Transaction `):
280+ (:class:`~google.cloud.firestore_v1.async_transaction.AsyncTransaction `):
278281 A transaction to execute the callable within.
279282 args (Tuple[Any, ...]): The extra positional arguments to pass
280283 along to the wrapped callable.
281284 kwargs (Dict[str, Any]): The extra keyword arguments to pass
282285 along to the wrapped callable.
283286
284287 Returns:
285- Any : The result of the wrapped callable.
288+ T : The result of the wrapped callable.
286289
287290 Raises:
288291 ValueError: If the transaction does not succeed in
@@ -320,14 +323,12 @@ async def __call__(self, transaction, *args, **kwargs):
320323 raise
321324
322325
323- def async_transactional (
324- to_wrap : Callable [[AsyncTransaction ], Any ]
325- ) -> _AsyncTransactional :
326+ def async_transactional (to_wrap : Callable [Concatenate [AsyncTransaction , P ], Awaitable [T ]]) -> Callable [Concatenate [AsyncTransaction , P ], Awaitable [T ]]:
326327 """Decorate a callable so that it runs in a transaction.
327328
328329 Args:
329330 to_wrap
330- (Callable[[:class:`~google.cloud.firestore_v1.transaction.Transaction `, ...], Any]):
331+ (Callable[[:class:`~google.cloud.firestore_v1.async_transaction.AsyncTransaction `, ...], Any]):
331332 A callable that should be run (and retried) in a transaction.
332333
333334 Returns:
0 commit comments