@@ -122,3 +122,151 @@ def test_insecure_deserialization_failure(self):
122122 # Since we now use json.loads, a pickle payload will fail to decode as UTF-8 / JSON
123123 with self .assertRaises ((json .JSONDecodeError , UnicodeDecodeError )):
124124 partition_helper .decode_from_string (encoded_pickle )
125+
126+ def test_protobuf_round_trip_reversibility (self ):
127+ # Any Protobuf message returned by Spanner options must be fully and reversibly
128+ # round-trip deserializable to its original class, not falling back to a dict.
129+ for cls_name , cls in partition_helper ._PROTO_CLASS_MAP .items ():
130+ instance = cls ()
131+ serialized = partition_helper ._serialize_value (instance )
132+ deserialized = partition_helper ._deserialize_value (serialized )
133+ self .assertIsInstance (deserialized , cls )
134+ self .assertEqual (deserialized .__class__ .__name__ , cls_name )
135+
136+ def test_dynamic_partition_options_registered (self ):
137+ # Dynamically verify that any Protobuf message class generated inside query_info or read_info
138+ # during partitioning is registered in _PROTO_CLASS_MAP.
139+ from google .protobuf .message import Message
140+ from unittest .mock import MagicMock
141+ from google .cloud .spanner_v1 .database import BatchSnapshot
142+ from google .cloud .spanner_v1 .types import ExecuteSqlRequest , DirectedReadOptions
143+
144+ db = MagicMock ()
145+ db .observability_options = {}
146+ db ._instance ._client ._query_options = ExecuteSqlRequest .QueryOptions (optimizer_version = "1" )
147+
148+ snapshot = BatchSnapshot (db )
149+ snapshot ._snapshot = MagicMock ()
150+ snapshot ._snapshot .partition_query .return_value = [b"token-123" ]
151+ snapshot ._snapshot .partition_read .return_value = [b"token-456" ]
152+
153+ query_options = ExecuteSqlRequest .QueryOptions (optimizer_version = "2" )
154+ directed_read_options = DirectedReadOptions ()
155+
156+ query_batches = list (snapshot .generate_query_batches (
157+ sql = "SELECT 1" ,
158+ query_options = query_options ,
159+ directed_read_options = directed_read_options
160+ ))
161+
162+ from google .cloud .spanner_v1 .keyset import KeySet
163+ read_batches = list (snapshot .generate_read_batches (
164+ table = "users" ,
165+ columns = ["name" ],
166+ keyset = KeySet (all_ = True ),
167+ directed_read_options = directed_read_options
168+ ))
169+
170+ discovered_protobuf_classes = set ()
171+
172+ def collect_protobufs (val ):
173+ if isinstance (val , dict ):
174+ for v in val .values ():
175+ collect_protobufs (v )
176+ elif isinstance (val , list ):
177+ for v in val :
178+ collect_protobufs (v )
179+ elif hasattr (val , "_pb" ) or isinstance (val , Message ):
180+ discovered_protobuf_classes .add (val .__class__ )
181+
182+ for batch in query_batches + read_batches :
183+ collect_protobufs (batch )
184+
185+ registered_classes = set (partition_helper ._PROTO_CLASS_MAP .values ())
186+ for cls in discovered_protobuf_classes :
187+ with self .subTest (cls = cls ):
188+ self .assertIn (
189+ cls ,
190+ registered_classes ,
191+ f"Protobuf class '{ cls .__name__ } ' is generated in partition batch details "
192+ f"but is not registered in partition_helper._PROTO_CLASS_MAP! "
193+ f"Please add it to _PROTO_CLASS_MAP to prevent silent deserialization failures."
194+ )
195+
196+ def test_all_spanner_param_types_round_trip (self ):
197+ import uuid
198+ import datetime
199+ import decimal
200+ from google .cloud .spanner_v1 .data_types import Interval , JsonObject
201+ from google .api_core .datetime_helpers import DatetimeWithNanoseconds
202+
203+ complex_params = {
204+ "uuid_val" : uuid .UUID ("12345678-1234-5678-1234-567812345678" ),
205+ "date_val" : datetime .date (2026 , 5 , 12 ),
206+ "decimal_val" : decimal .Decimal ("99999.99" ),
207+ "interval_val" : Interval (months = 35 , days = 12 , nanos = 54321000 ),
208+ "json_val" : JsonObject ({"name" : "Alice" , "active" : True }),
209+ "timestamp_val" : datetime .datetime (2026 , 5 , 12 , 12 , 34 , 56 , tzinfo = datetime .timezone .utc ),
210+ "timestamp_nanos_val" : DatetimeWithNanoseconds (2026 , 5 , 12 , 12 , 34 , 56 , 123456 , tzinfo = datetime .timezone .utc ),
211+ "bytes_val" : b"binary-data" ,
212+ "bool_val" : True ,
213+ "int_val" : 100 ,
214+ "float_val" : 123.45 ,
215+ "str_val" : "hello-world" ,
216+ "none_val" : None
217+ }
218+
219+ # DYNAMIC AST VERIFICATION OF CORE SDK SUPPORTED TYPES
220+ # Dynamically discover all classes checked by `isinstance` inside Spanner's `_make_value_pb`.
221+ import ast
222+ import inspect
223+ from google .cloud .spanner_v1 ._helpers import _make_value_pb
224+
225+ source = inspect .getsource (_make_value_pb )
226+ tree = ast .parse (source )
227+
228+ discovered_types = set ()
229+ for node in ast .walk (tree ):
230+ if isinstance (node , ast .Call ) and isinstance (node .func , ast .Name ) and node .func .id == "isinstance" :
231+ if len (node .args ) >= 2 and isinstance (node .args [0 ], ast .Name ) and node .args [0 ].id == "value" :
232+ type_arg = node .args [1 ]
233+ if isinstance (type_arg , ast .Tuple ):
234+ for elt in type_arg .elts :
235+ if isinstance (elt , ast .Name ):
236+ discovered_types .add (elt .id )
237+ elif isinstance (type_arg , ast .Name ):
238+ discovered_types .add (type_arg .id )
239+ elif isinstance (type_arg , ast .Attribute ):
240+ discovered_types .add (type_arg .attr )
241+
242+ # Map our test's `complex_params` actual class/instance types to the class name strings
243+ test_param_class_names = set ()
244+ for val in complex_params .values ():
245+ if val is not None :
246+ test_param_class_names .add (val .__class__ .__name__ )
247+ # Also map base classes (e.g., DatetimeWithNanoseconds inherits from datetime)
248+ for base in val .__class__ .__mro__ :
249+ test_param_class_names .add (base .__name__ )
250+
251+ # Special mappings for primitive built-ins checked in standard library or tuple serialization
252+ test_param_class_names .update ({"list" , "tuple" , "Message" , "ListValue" })
253+
254+ # Assert that every type validated in Spanner SDK's _make_value_pb
255+ # has a corresponding test parameter type implemented in our round-trip check
256+ for sdk_type in discovered_types :
257+ with self .subTest (sdk_type = sdk_type ):
258+ self .assertIn (
259+ sdk_type ,
260+ test_param_class_names ,
261+ f"Spanner SDK parameter helper (_make_value_pb) supports type '{ sdk_type } ', "
262+ f"but this type has not been implemented/mapped in the DB-API partition "
263+ f"helper tests! Please add a verification case for it."
264+ )
265+
266+ # For each parameter type, try round-trip serialization
267+ for key , val in complex_params .items ():
268+ with self .subTest (type_key = key ):
269+ serialized = partition_helper ._serialize_value (val )
270+ json_str = json .dumps (serialized )
271+ deserialized = partition_helper ._deserialize_value (json .loads (json_str ))
272+ self .assertEqual (deserialized , val , f"Round-trip failed for { key } ! Original: { val } , Deserialized: { deserialized } " )
0 commit comments