Skip to content
This repository was archived by the owner on Mar 31, 2026. It is now read-only.

Commit 4335fbf

Browse files
committed
test: improve async test coverage
1 parent 8a91874 commit 4335fbf

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

43 files changed

+5867
-488
lines changed

google/cloud/aio/_cross_sync/cross_sync.py

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,8 @@ class CrossSync(metaclass=MappingMeta):
8989
LifoQueue: TypeAlias = asyncio.LifoQueue
9090
PriorityQueue: TypeAlias = asyncio.PriorityQueue
9191
StopIteration: TypeAlias = StopAsyncIteration
92+
QueueEmpty: TypeAlias = asyncio.QueueEmpty
93+
QueueFull: TypeAlias = asyncio.QueueFull
9294
# provide aliases for common async type annotations
9395
Awaitable: TypeAlias = typing.Awaitable
9496
Iterable: TypeAlias = AsyncIterable
@@ -160,17 +162,29 @@ async def run_if_async(func, *args, **kwargs):
160162
@staticmethod
161163
async def queue_get(queue, block=True, timeout=None):
162164
if not block:
163-
return queue.get_nowait()
165+
try:
166+
return queue.get_nowait()
167+
except asyncio.QueueEmpty:
168+
raise CrossSync.QueueEmpty()
164169
if timeout is not None:
165-
return await asyncio.wait_for(queue.get(), timeout=timeout)
170+
try:
171+
return await asyncio.wait_for(queue.get(), timeout=timeout)
172+
except asyncio.TimeoutError:
173+
raise CrossSync.QueueEmpty()
166174
return await queue.get()
167175

168176
@staticmethod
169177
async def queue_put(queue, item, block=True, timeout=None):
170178
if not block:
171-
return queue.put_nowait(item)
179+
try:
180+
return queue.put_nowait(item)
181+
except asyncio.QueueFull:
182+
raise CrossSync.QueueFull()
172183
if timeout is not None:
173-
await asyncio.wait_for(queue.put(item), timeout=timeout)
184+
try:
185+
await asyncio.wait_for(queue.put(item), timeout=timeout)
186+
except asyncio.TimeoutError:
187+
raise CrossSync.QueueFull()
174188
else:
175189
await queue.put(item)
176190

@@ -304,6 +318,8 @@ class _Sync_Impl(metaclass=MappingMeta):
304318
Semaphore: TypeAlias = threading.Semaphore
305319
LifoQueue: TypeAlias = queue.LifoQueue
306320
PriorityQueue: TypeAlias = queue.PriorityQueue
321+
QueueEmpty: TypeAlias = queue.Empty
322+
QueueFull: TypeAlias = queue.Full
307323
StopIteration: TypeAlias = StopIteration
308324
# type annotations
309325
Awaitable: TypeAlias = Union[T]

google/cloud/spanner_v1/_async/_helpers.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,9 @@ async def _retry(
6060
if retries >= retry_count:
6161
raise e
6262
if before_next_retry:
63-
before_next_retry(retries, delay)
63+
res = before_next_retry(retries, delay)
64+
if asyncio.iscoroutine(res) or inspect.isawaitable(res):
65+
await res
6466
await asyncio.sleep(delay)
6567
retries += 1
6668

google/cloud/spanner_v1/_async/batch.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -21,14 +21,7 @@
2121
from google.api_core.exceptions import InternalServerError
2222

2323
from google.cloud.aio._cross_sync import CrossSync
24-
from google.cloud.spanner_v1 import (
25-
BatchWriteRequest,
26-
CommitRequest,
27-
CommitResponse,
28-
Mutation,
29-
RequestOptions,
30-
TransactionOptions,
31-
)
24+
3225
from google.cloud.spanner_v1._async._helpers import _retry, _retry_on_aborted_exception
3326
from google.cloud.spanner_v1._helpers import (
3427
AtomicCounter,
@@ -45,6 +38,14 @@
4538
)
4639
from google.cloud.spanner_v1._opentelemetry_tracing import trace_call
4740
from google.cloud.spanner_v1.metrics.metrics_capture import MetricsCapture
41+
from google.cloud.spanner_v1.types.commit_response import CommitResponse
42+
from google.cloud.spanner_v1.types.mutation import Mutation
43+
from google.cloud.spanner_v1.types.spanner import (
44+
BatchWriteRequest,
45+
CommitRequest,
46+
RequestOptions,
47+
)
48+
from google.cloud.spanner_v1.types.transaction import TransactionOptions
4849

4950
DEFAULT_RETRY_TIMEOUT_SECS = 30
5051

google/cloud/spanner_v1/_async/client.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -66,22 +66,22 @@
6666
ListInstanceConfigsRequest,
6767
ListInstancesRequest,
6868
)
69-
from google.cloud.spanner_v1.gapic_version import __version__
70-
from google.cloud.spanner_v1.transaction import DefaultTransactionOptions
71-
from google.cloud.spanner_v1.types.spanner import ExecuteSqlRequest
69+
from google.cloud.spanner_v1._async.instance import Instance
7270
from google.cloud.spanner_v1._helpers import (
7371
_create_experimental_host_transport,
7472
_validate_client_context,
7573
)
76-
from google.cloud.spanner_v1._async.instance import Instance
7774
from google.cloud.spanner_v1._helpers import _merge_query_options, _metadata_with_prefix
75+
from google.cloud.spanner_v1.gapic_version import __version__
7876
from google.cloud.spanner_v1.metrics.constants import METRIC_EXPORT_INTERVAL_MS
7977
from google.cloud.spanner_v1.metrics.metrics_exporter import (
8078
CloudMonitoringMetricsExporter,
8179
)
8280
from google.cloud.spanner_v1.metrics.spanner_metrics_tracer_factory import (
8381
SpannerMetricsTracerFactory,
8482
)
83+
from google.cloud.spanner_v1.transaction import DefaultTransactionOptions
84+
from google.cloud.spanner_v1.types.spanner import ExecuteSqlRequest
8585

8686
try:
8787
from opentelemetry import metrics

google/cloud/spanner_v1/_async/database.py

Lines changed: 44 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -43,13 +43,6 @@
4343
from google.cloud.spanner_admin_database_v1 import CreateDatabaseRequest
4444
from google.cloud.spanner_admin_database_v1 import Database as DatabasePB
4545
from google.cloud.spanner_admin_database_v1.types import DatabaseDialect
46-
from google.cloud.spanner_v1.transaction import DefaultTransactionOptions
47-
from google.cloud.spanner_v1.types.spanner import ExecuteSqlRequest
48-
from google.cloud.spanner_v1.types.spanner import RequestOptions
49-
from google.cloud.spanner_v1.types.transaction import TransactionOptions
50-
from google.cloud.spanner_v1.types.transaction import TransactionSelector
51-
from google.cloud.spanner_v1.types.type import Type
52-
from google.cloud.spanner_v1.types.type import TypeCode
5346
from google.cloud.spanner_v1._async.batch import Batch, MutationGroups
5447
from google.cloud.spanner_v1._async.database_sessions_manager import (
5548
DatabaseSessionsManager,
@@ -72,7 +65,16 @@
7265
from google.cloud.spanner_v1.services.spanner.async_client import (
7366
SpannerAsyncClient as SpannerClient,
7467
)
75-
from google.cloud.spanner_v1.transaction import BatchTransactionId
68+
from google.cloud.spanner_v1.transaction import (
69+
BatchTransactionId,
70+
DefaultTransactionOptions,
71+
)
72+
from google.cloud.spanner_v1.types.spanner import ExecuteSqlRequest, RequestOptions
73+
from google.cloud.spanner_v1.types.transaction import (
74+
TransactionOptions,
75+
TransactionSelector,
76+
)
77+
from google.cloud.spanner_v1.types.type import Type, TypeCode
7678

7779
if CrossSync.is_async:
7880
from google.cloud.spanner_v1.services.spanner.transports.grpc_asyncio import (
@@ -198,13 +200,22 @@ def __init__(
198200
self._encryption_config = encryption_config
199201
self._database_dialect = database_dialect
200202
self._database_role = database_role
201-
self._route_to_leader_enabled = self._instance._client.route_to_leader_enabled
203+
if self._instance and self._instance._client:
204+
self._route_to_leader_enabled = (
205+
self._instance._client.route_to_leader_enabled
206+
)
207+
else:
208+
self._route_to_leader_enabled = False
202209
self._enable_drop_protection = enable_drop_protection
203210
self._reconciling = False
204-
self._directed_read_options = self._instance._client.directed_read_options
205-
self.default_transaction_options: DefaultTransactionOptions = (
206-
self._instance._client.default_transaction_options
207-
)
211+
if self._instance and self._instance._client:
212+
self._directed_read_options = self._instance._client.directed_read_options
213+
self.default_transaction_options: DefaultTransactionOptions = (
214+
self._instance._client.default_transaction_options
215+
)
216+
else:
217+
self._directed_read_options = None
218+
self.default_transaction_options = None
208219
self._proto_descriptors = proto_descriptors
209220
self._channel_id = 0 # It'll be created when _spanner_api is created.
210221

@@ -220,7 +231,9 @@ def __init__(
220231
except RuntimeError:
221232
# No running loop, bind should have been sync or will be failed later
222233
pass
223-
self._experimental_host = self._instance._client._experimental_host
234+
self._experimental_host = (
235+
self._instance.experimental_host if self._instance else None
236+
)
224237
is_experimental_host = self._experimental_host is not None
225238

226239
self._sessions_manager = DatabaseSessionsManager(
@@ -231,8 +244,12 @@ def __init__(
231244
def _resource_info(self):
232245
"""Resource information for metrics labels."""
233246
return {
234-
"project": self._instance._client.project,
235-
"instance": self._instance.instance_id,
247+
"project": (
248+
self._instance._client.project
249+
if self._instance and self._instance._client
250+
else None
251+
),
252+
"instance": self._instance.instance_id if self._instance else None,
236253
"database": self.database_id,
237254
}
238255

@@ -1351,7 +1368,8 @@ async def list_tables(self, schema="_default"):
13511368
async for row in results:
13521369
yield self.table(row[0])
13531370

1354-
def get_iam_policy(self, policy_version=None):
1371+
@CrossSync.convert
1372+
async def get_iam_policy(self, policy_version=None):
13551373
"""Gets the access control policy for a database resource.
13561374
13571375
:type policy_version: int
@@ -1374,13 +1392,14 @@ def get_iam_policy(self, policy_version=None):
13741392
requested_policy_version=policy_version
13751393
),
13761394
)
1377-
response = api.get_iam_policy(
1395+
response = await api.get_iam_policy(
13781396
request=request,
13791397
metadata=self.metadata_with_request_id(self._next_nth_request, 1, metadata),
13801398
)
13811399
return response
13821400

1383-
def set_iam_policy(self, policy):
1401+
@CrossSync.convert
1402+
async def set_iam_policy(self, policy):
13841403
"""Sets the access control policy on a database resource.
13851404
Replaces any existing policy.
13861405
@@ -1399,7 +1418,7 @@ def set_iam_policy(self, policy):
13991418
resource=self.name,
14001419
policy=policy,
14011420
)
1402-
response = api.set_iam_policy(
1421+
response = await api.set_iam_policy(
14031422
request=request,
14041423
metadata=self.metadata_with_request_id(self._next_nth_request, 1, metadata),
14051424
)
@@ -1430,6 +1449,11 @@ def sessions_manager(self) -> DatabaseSessionsManager:
14301449
"""
14311450
return self._sessions_manager
14321451

1452+
@CrossSync.convert
1453+
async def close(self):
1454+
"""Clean up underlying session manager and background tasks."""
1455+
await self._sessions_manager.close()
1456+
14331457

14341458
class BatchCheckout(object):
14351459
"""Context manager for using a batch from a database.

google/cloud/spanner_v1/_async/database_sessions_manager.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
"""Manage sessions for a database."""
1616
__CROSS_SYNC_OUTPUT__ = "google.cloud.spanner_v1.database_sessions_manager"
1717

18+
import asyncio
1819
from datetime import timedelta
1920
from enum import Enum
2021
from os import getenv
@@ -215,3 +216,20 @@ def _getenv(cls, env_var_name: str) -> bool:
215216
"""Returns the value of the given environment variable as a boolean."""
216217
env_var_value = getenv(env_var_name, "true").lower().strip()
217218
return env_var_value != "false"
219+
220+
@CrossSync.convert
221+
async def close(self) -> None:
222+
"""Closes the database session manager and stops all background tasks."""
223+
self._multiplexed_session_terminate_event.set()
224+
if self._multiplexed_session_thread is not None:
225+
if CrossSync.is_async:
226+
self._multiplexed_session_thread.cancel()
227+
try:
228+
await self._multiplexed_session_thread
229+
except CrossSync.rm_aio(asyncio.CancelledError):
230+
pass
231+
else:
232+
self._multiplexed_session_thread.join()
233+
if self._multiplexed_session is not None:
234+
await self._multiplexed_session.delete()
235+
self._multiplexed_session = None

0 commit comments

Comments
 (0)