Skip to content
This repository was archived by the owner on Mar 31, 2026. It is now read-only.

Commit 35dba42

Browse files
committed
chore: move ClientContext validation to a helper function
1 parent 424223d commit 35dba42

File tree

4 files changed

+25
-21
lines changed

4 files changed

+25
-21
lines changed

google/cloud/spanner_v1/_helpers.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -230,6 +230,25 @@ def _merge_client_context(base, merge):
230230
return combined
231231

232232

233+
def _validate_client_context(client_context):
234+
"""Validate and convert client_context.
235+
236+
:type client_context: :class:`~google.cloud.spanner_v1.types.ClientContext`
237+
or :class:`dict`
238+
:param client_context: (Optional) Client context to use.
239+
240+
:rtype: :class:`~google.cloud.spanner_v1.types.ClientContext`
241+
:returns: Validated ClientContext object or None.
242+
:raises TypeError: if client_context is not a ClientContext or a dict.
243+
"""
244+
if client_context is not None:
245+
if isinstance(client_context, dict):
246+
client_context = ClientContext(client_context)
247+
elif not isinstance(client_context, ClientContext):
248+
raise TypeError("client_context must be a ClientContext or a dict")
249+
return client_context
250+
251+
233252
def _merge_request_options(request_options, client_context):
234253
"""Merge RequestOptions and ClientContext.
235254

google/cloud/spanner_v1/batch.py

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
_merge_Transaction_Options,
3131
_merge_client_context,
3232
_merge_request_options,
33+
_validate_client_context,
3334
AtomicCounter,
3435
)
3536
from google.cloud.spanner_v1._opentelemetry_tracing import trace_call
@@ -66,13 +67,7 @@ def __init__(self, session, client_context=None):
6667
self.committed = None
6768
"""Timestamp at which the batch was successfully committed."""
6869
self.commit_stats: Optional[CommitResponse.CommitStats] = None
69-
70-
if client_context is not None:
71-
if isinstance(client_context, dict):
72-
client_context = ClientContext(client_context)
73-
elif not isinstance(client_context, ClientContext):
74-
raise TypeError("client_context must be a ClientContext or a dict")
75-
self._client_context = client_context
70+
self._client_context = _validate_client_context(client_context)
7671

7772
def insert(self, table, columns, values):
7873
"""Insert one or more new table rows.

google/cloud/spanner_v1/client.py

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@
5353
from google.cloud.spanner_v1.types import ClientContext
5454
from google.cloud.spanner_v1._helpers import _merge_query_options
5555
from google.cloud.spanner_v1._helpers import _metadata_with_prefix
56+
from google.cloud.spanner_v1._helpers import _validate_client_context
5657
from google.cloud.spanner_v1.instance import Instance
5758
from google.cloud.spanner_v1.metrics.constants import (
5859
METRIC_EXPORT_INTERVAL_MS,
@@ -293,13 +294,7 @@ def __init__(
293294

294295
# Environment flag config has higher precedence than application config.
295296
self._query_options = _merge_query_options(query_options, env_query_options)
296-
297-
if client_context is not None:
298-
if isinstance(client_context, dict):
299-
client_context = ClientContext(client_context)
300-
elif not isinstance(client_context, ClientContext):
301-
raise TypeError("client_context must be a ClientContext or a dict")
302-
self._client_context = client_context
297+
self._client_context = _validate_client_context(client_context)
303298

304299
if self._emulator_host is not None and (
305300
"http://" in self._emulator_host or "https://" in self._emulator_host

google/cloud/spanner_v1/snapshot.py

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@
5050
_SessionWrapper,
5151
AtomicCounter,
5252
_augment_error_with_request_id,
53+
_validate_client_context,
5354
)
5455
from google.cloud.spanner_v1._opentelemetry_tracing import trace_call, add_span_event
5556
from google.cloud.spanner_v1.streamed import StreamedResultSet
@@ -212,13 +213,7 @@ class _SnapshotBase(_SessionWrapper):
212213
def __init__(self, session, client_context=None):
213214
super().__init__(session)
214215

215-
if client_context is not None:
216-
if isinstance(client_context, dict):
217-
client_context = ClientContext(client_context)
218-
elif not isinstance(client_context, ClientContext):
219-
raise TypeError("client_context must be a ClientContext or a dict")
220-
self._client_context = client_context
221-
216+
self._client_context = _validate_client_context(client_context)
222217
# Counts for execute SQL requests and total read requests (including
223218
# execute SQL requests). Used to provide sequence numbers for
224219
# :class:`google.cloud.spanner_v1.types.ExecuteSqlRequest` and to

0 commit comments

Comments
 (0)