Skip to content

Commit 5ef6394

Browse files
committed
fix: add support for more parameter types
1 parent bff0b73 commit 5ef6394

5 files changed

Lines changed: 307 additions & 3 deletions

File tree

packages/google-cloud-spanner/google/cloud/spanner_dbapi/partition_helper.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,12 @@
1414

1515
import base64
1616
import datetime
17+
import decimal
1718
import gzip
1819
import json
20+
import uuid
1921
from dataclasses import dataclass
22+
from google.cloud.spanner_v1.data_types import Interval, JsonObject
2023
from typing import Any
2124

2225
from google.protobuf.json_format import MessageToDict, ParseDict
@@ -36,6 +39,14 @@ def _serialize_value(val: Any) -> Any:
3639
return {"__type__": "bytes", "value": base64.b64encode(val).decode("utf-8")}
3740
elif isinstance(val, datetime.datetime):
3841
return {"__type__": "datetime", "value": val.isoformat()}
42+
elif isinstance(val, datetime.date):
43+
return {"__type__": "date", "value": val.isoformat()}
44+
elif isinstance(val, decimal.Decimal):
45+
return {"__type__": "decimal", "value": str(val)}
46+
elif isinstance(val, uuid.UUID):
47+
return {"__type__": "uuid", "value": str(val)}
48+
elif isinstance(val, Interval):
49+
return {"__type__": "interval", "value": str(val)}
3950
elif hasattr(val, "_pb"):
4051
return {
4152
"__type__": "protobuf",
@@ -48,6 +59,8 @@ def _serialize_value(val: Any) -> Any:
4859
"class": val.__class__.__name__,
4960
"value": MessageToDict(val, preserving_proto_field_name=True),
5061
}
62+
elif isinstance(val, JsonObject):
63+
return {"__type__": "json_object", "value": val.serialize()}
5164
elif isinstance(val, dict):
5265
return {k: _serialize_value(v) for k, v in val.items()}
5366
elif isinstance(val, list):
@@ -68,6 +81,16 @@ def _deserialize_value(val: Any) -> Any:
6881
if dt_str.endswith("Z"):
6982
dt_str = dt_str[:-1] + "+00:00"
7083
return datetime.datetime.fromisoformat(dt_str)
84+
elif t == "date":
85+
return datetime.date.fromisoformat(val["value"])
86+
elif t == "decimal":
87+
return decimal.Decimal(val["value"])
88+
elif t == "uuid":
89+
return uuid.UUID(val["value"])
90+
elif t == "interval":
91+
return Interval.from_str(val["value"])
92+
elif t == "json_object":
93+
return JsonObject.from_str(val["value"])
7194
elif t == "tuple":
7295
return tuple(_deserialize_value(x) for x in val["value"])
7396
elif t == "protobuf":

packages/google-cloud-spanner/google/cloud/spanner_v1/testing/mock_spanner.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,16 +36,21 @@ class MockSpanner:
3636
def __init__(self):
3737
self.results = {}
3838
self.execute_streaming_sql_results = {}
39+
self.partition_results = {}
3940
self.errors = {}
4041

4142
def clear_results(self):
4243
self.results = {}
4344
self.execute_streaming_sql_results = {}
45+
self.partition_results = {}
4446
self.errors = {}
4547

4648
def add_result(self, sql: str, result: result_set.ResultSet):
4749
self.results[sql.lower().strip()] = result
4850

51+
def add_partition_result(self, sql: str, result: spanner.PartitionResponse):
52+
self.partition_results[sql.lower().strip()] = result
53+
4954
def add_execute_streaming_sql_results(
5055
self, sql: str, partial_result_sets: list[result_set.PartialResultSet]
5156
):
@@ -57,6 +62,12 @@ def get_result(self, sql: str) -> result_set.ResultSet:
5762
raise ValueError(f"No result found for {sql}")
5863
return result
5964

65+
def get_partition_result(self, sql: str) -> spanner.PartitionResponse:
66+
result = self.partition_results.get(sql.lower().strip())
67+
if result is None:
68+
return spanner.PartitionResponse()
69+
return result
70+
6071
def add_error(self, method: str, error: _Status):
6172
if not hasattr(self, "_errors_list"):
6273
self._errors_list = {}
@@ -300,11 +311,12 @@ def Rollback(self, request, context):
300311

301312
def PartitionQuery(self, request, context):
302313
self._requests.append(request)
303-
return spanner.PartitionResponse()
314+
return self.mock_spanner.get_partition_result(request.sql)
304315

305316
def PartitionRead(self, request, context):
306317
self._requests.append(request)
307-
return spanner.PartitionResponse()
318+
# For reads, look up by target table name
319+
return self.mock_spanner.get_partition_result(request.table)
308320

309321
def BatchWrite(self, request, context):
310322
self._requests.append(request)

packages/google-cloud-spanner/tests/_helpers.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,10 @@
11
from os import getenv
22
from unittest import IsolatedAsyncioTestCase
33

4-
import mock
4+
try:
5+
import mock
6+
except ImportError:
7+
import unittest.mock as mock
58

69
from google.cloud.spanner_v1 import gapic_version
710
from google.cloud.spanner_v1.database_sessions_manager import TransactionType
Lines changed: 118 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,118 @@
1+
# Copyright 2024 Google LLC All rights reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import unittest
16+
from google.cloud.spanner_dbapi.connection import Connection
17+
from google.cloud.spanner_v1.types import spanner as spanner_types
18+
from google.cloud.spanner_v1 import TypeCode
19+
from tests.mockserver_tests.mock_server_test_base import MockServerTestBase, add_single_result
20+
from google.cloud.spanner_dbapi.parsed_statement import ParsedStatement, Statement
21+
22+
23+
class TestDbapiPartitionQuery(MockServerTestBase):
24+
def test_partition_query_and_run_partition(self):
25+
sql = "SELECT name FROM users WHERE active = true"
26+
27+
# 1. Set up mock results for PartitionQuery RPC in the mock servicer
28+
partition_response = spanner_types.PartitionResponse()
29+
partition_response.partitions.extend([
30+
spanner_types.Partition(partition_token=b"mock-token-1"),
31+
spanner_types.Partition(partition_token=b"mock-token-2")
32+
])
33+
self.spanner_service.mock_spanner.add_partition_result(sql, partition_response)
34+
35+
# 2. Set up mock results for ExecuteSql when executing the partitions
36+
add_single_result(sql, "name", TypeCode.STRING, [("Alice",), ("Bob",)])
37+
38+
# 3. Connect via DB-API and mark connection as read-only (required for partitioning)
39+
connection = Connection(self.instance, self.database)
40+
connection._read_only = True
41+
42+
# Define partitioning parameters inside DB-API Statement
43+
from google.cloud.spanner_dbapi.parsed_statement import StatementType, ClientSideStatementType
44+
parsed = ParsedStatement(
45+
statement_type=StatementType.CLIENT_SIDE,
46+
statement=Statement(sql),
47+
client_side_statement_type=ClientSideStatementType.PARTITION_QUERY,
48+
client_side_statement_params=["SELECT name FROM users WHERE active = true"]
49+
)
50+
51+
# Generate serialized token strings (Base64 + GZip JSON)
52+
partition_ids = connection.partition_query(parsed)
53+
self.assertEqual(2, len(partition_ids))
54+
55+
# 4. Reconstruct & Execute the partitions by deserializing their tokens
56+
all_names = []
57+
for token in partition_ids:
58+
result_stream = connection.run_partition(token)
59+
for row in result_stream:
60+
all_names.append(row[0])
61+
62+
# Verify results are successfully round-tripped and parsed
63+
self.assertIn("Alice", all_names)
64+
self.assertIn("Bob", all_names)
65+
66+
def test_partition_query_with_complex_parameters(self):
67+
import decimal
68+
import datetime
69+
70+
sql = "SELECT name FROM users WHERE active = @active AND salary > @salary AND signup_time = @signup_time"
71+
72+
# Set up complex parameter values (bool, Decimal, datetime)
73+
params = {
74+
"active": True,
75+
"salary": decimal.Decimal("75000.50"),
76+
"signup_time": datetime.datetime(2026, 5, 10, 12, 34, 56, tzinfo=datetime.timezone.utc)
77+
}
78+
from google.cloud.spanner_v1 import Type
79+
param_types = {
80+
"active": Type(code=TypeCode.BOOL),
81+
"salary": Type(code=TypeCode.NUMERIC),
82+
"signup_time": Type(code=TypeCode.TIMESTAMP)
83+
}
84+
85+
# 1. Mock results for the partition generation RPC
86+
partition_response = spanner_types.PartitionResponse()
87+
partition_response.partitions.extend([
88+
spanner_types.Partition(partition_token=b"complex-mock-token-1")
89+
])
90+
self.spanner_service.mock_spanner.add_partition_result(sql, partition_response)
91+
92+
# 2. Mock results for execution of partition streaming SQL
93+
add_single_result(sql, "name", TypeCode.STRING, [("Charlie",)])
94+
95+
# 3. Establish Connection
96+
connection = Connection(self.instance, self.database)
97+
connection._read_only = True
98+
99+
from google.cloud.spanner_dbapi.parsed_statement import StatementType, ClientSideStatementType
100+
parsed = ParsedStatement(
101+
statement_type=StatementType.CLIENT_SIDE,
102+
statement=Statement(sql, params=params, param_types=param_types),
103+
client_side_statement_type=ClientSideStatementType.PARTITION_QUERY,
104+
client_side_statement_params=[sql]
105+
)
106+
107+
# Execute partition generation - this serializes query parameters!
108+
partition_ids = connection.partition_query(parsed)
109+
self.assertEqual(1, len(partition_ids))
110+
111+
# 4. Reconstruct and run the partition E2E
112+
all_names = []
113+
for token in partition_ids:
114+
result_stream = connection.run_partition(token)
115+
for row in result_stream:
116+
all_names.append(row[0])
117+
118+
self.assertEqual(["Charlie"], all_names)

packages/google-cloud-spanner/tests/unit/spanner_dbapi/test_partition_helper.py

Lines changed: 148 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -122,3 +122,151 @@ def test_insecure_deserialization_failure(self):
122122
# Since we now use json.loads, a pickle payload will fail to decode as UTF-8 / JSON
123123
with self.assertRaises((json.JSONDecodeError, UnicodeDecodeError)):
124124
partition_helper.decode_from_string(encoded_pickle)
125+
126+
def test_protobuf_round_trip_reversibility(self):
127+
# Any Protobuf message returned by Spanner options must be fully and reversibly
128+
# round-trip deserializable to its original class, not falling back to a dict.
129+
for cls_name, cls in partition_helper._PROTO_CLASS_MAP.items():
130+
instance = cls()
131+
serialized = partition_helper._serialize_value(instance)
132+
deserialized = partition_helper._deserialize_value(serialized)
133+
self.assertIsInstance(deserialized, cls)
134+
self.assertEqual(deserialized.__class__.__name__, cls_name)
135+
136+
def test_dynamic_partition_options_registered(self):
137+
# Dynamically verify that any Protobuf message class generated inside query_info or read_info
138+
# during partitioning is registered in _PROTO_CLASS_MAP.
139+
from google.protobuf.message import Message
140+
from unittest.mock import MagicMock
141+
from google.cloud.spanner_v1.database import BatchSnapshot
142+
from google.cloud.spanner_v1.types import ExecuteSqlRequest, DirectedReadOptions
143+
144+
db = MagicMock()
145+
db.observability_options = {}
146+
db._instance._client._query_options = ExecuteSqlRequest.QueryOptions(optimizer_version="1")
147+
148+
snapshot = BatchSnapshot(db)
149+
snapshot._snapshot = MagicMock()
150+
snapshot._snapshot.partition_query.return_value = [b"token-123"]
151+
snapshot._snapshot.partition_read.return_value = [b"token-456"]
152+
153+
query_options = ExecuteSqlRequest.QueryOptions(optimizer_version="2")
154+
directed_read_options = DirectedReadOptions()
155+
156+
query_batches = list(snapshot.generate_query_batches(
157+
sql="SELECT 1",
158+
query_options=query_options,
159+
directed_read_options=directed_read_options
160+
))
161+
162+
from google.cloud.spanner_v1.keyset import KeySet
163+
read_batches = list(snapshot.generate_read_batches(
164+
table="users",
165+
columns=["name"],
166+
keyset=KeySet(all_=True),
167+
directed_read_options=directed_read_options
168+
))
169+
170+
discovered_protobuf_classes = set()
171+
172+
def collect_protobufs(val):
173+
if isinstance(val, dict):
174+
for v in val.values():
175+
collect_protobufs(v)
176+
elif isinstance(val, list):
177+
for v in val:
178+
collect_protobufs(v)
179+
elif hasattr(val, "_pb") or isinstance(val, Message):
180+
discovered_protobuf_classes.add(val.__class__)
181+
182+
for batch in query_batches + read_batches:
183+
collect_protobufs(batch)
184+
185+
registered_classes = set(partition_helper._PROTO_CLASS_MAP.values())
186+
for cls in discovered_protobuf_classes:
187+
with self.subTest(cls=cls):
188+
self.assertIn(
189+
cls,
190+
registered_classes,
191+
f"Protobuf class '{cls.__name__}' is generated in partition batch details "
192+
f"but is not registered in partition_helper._PROTO_CLASS_MAP! "
193+
f"Please add it to _PROTO_CLASS_MAP to prevent silent deserialization failures."
194+
)
195+
196+
def test_all_spanner_param_types_round_trip(self):
197+
import uuid
198+
import datetime
199+
import decimal
200+
from google.cloud.spanner_v1.data_types import Interval, JsonObject
201+
from google.api_core.datetime_helpers import DatetimeWithNanoseconds
202+
203+
complex_params = {
204+
"uuid_val": uuid.UUID("12345678-1234-5678-1234-567812345678"),
205+
"date_val": datetime.date(2026, 5, 12),
206+
"decimal_val": decimal.Decimal("99999.99"),
207+
"interval_val": Interval(months=35, days=12, nanos=54321000),
208+
"json_val": JsonObject({"name": "Alice", "active": True}),
209+
"timestamp_val": datetime.datetime(2026, 5, 12, 12, 34, 56, tzinfo=datetime.timezone.utc),
210+
"timestamp_nanos_val": DatetimeWithNanoseconds(2026, 5, 12, 12, 34, 56, 123456, tzinfo=datetime.timezone.utc),
211+
"bytes_val": b"binary-data",
212+
"bool_val": True,
213+
"int_val": 100,
214+
"float_val": 123.45,
215+
"str_val": "hello-world",
216+
"none_val": None
217+
}
218+
219+
# DYNAMIC AST VERIFICATION OF CORE SDK SUPPORTED TYPES
220+
# Dynamically discover all classes checked by `isinstance` inside Spanner's `_make_value_pb`.
221+
import ast
222+
import inspect
223+
from google.cloud.spanner_v1._helpers import _make_value_pb
224+
225+
source = inspect.getsource(_make_value_pb)
226+
tree = ast.parse(source)
227+
228+
discovered_types = set()
229+
for node in ast.walk(tree):
230+
if isinstance(node, ast.Call) and isinstance(node.func, ast.Name) and node.func.id == "isinstance":
231+
if len(node.args) >= 2 and isinstance(node.args[0], ast.Name) and node.args[0].id == "value":
232+
type_arg = node.args[1]
233+
if isinstance(type_arg, ast.Tuple):
234+
for elt in type_arg.elts:
235+
if isinstance(elt, ast.Name):
236+
discovered_types.add(elt.id)
237+
elif isinstance(type_arg, ast.Name):
238+
discovered_types.add(type_arg.id)
239+
elif isinstance(type_arg, ast.Attribute):
240+
discovered_types.add(type_arg.attr)
241+
242+
# Map our test's `complex_params` actual class/instance types to the class name strings
243+
test_param_class_names = set()
244+
for val in complex_params.values():
245+
if val is not None:
246+
test_param_class_names.add(val.__class__.__name__)
247+
# Also map base classes (e.g., DatetimeWithNanoseconds inherits from datetime)
248+
for base in val.__class__.__mro__:
249+
test_param_class_names.add(base.__name__)
250+
251+
# Special mappings for primitive built-ins checked in standard library or tuple serialization
252+
test_param_class_names.update({"list", "tuple", "Message", "ListValue"})
253+
254+
# Assert that every type validated in Spanner SDK's _make_value_pb
255+
# has a corresponding test parameter type implemented in our round-trip check
256+
for sdk_type in discovered_types:
257+
with self.subTest(sdk_type=sdk_type):
258+
self.assertIn(
259+
sdk_type,
260+
test_param_class_names,
261+
f"Spanner SDK parameter helper (_make_value_pb) supports type '{sdk_type}', "
262+
f"but this type has not been implemented/mapped in the DB-API partition "
263+
f"helper tests! Please add a verification case for it."
264+
)
265+
266+
# For each parameter type, try round-trip serialization
267+
for key, val in complex_params.items():
268+
with self.subTest(type_key=key):
269+
serialized = partition_helper._serialize_value(val)
270+
json_str = json.dumps(serialized)
271+
deserialized = partition_helper._deserialize_value(json.loads(json_str))
272+
self.assertEqual(deserialized, val, f"Round-trip failed for {key}! Original: {val}, Deserialized: {deserialized}")

0 commit comments

Comments (0)