Skip to content

Commit bff0b73

Browse files
committed
fix(spanner_dbapi): replace insecure pickle with json for partition deserialization
1 parent 471eb13 commit bff0b73

2 files changed

Lines changed: 209 additions & 4 deletions

File tree

packages/google-cloud-spanner/google/cloud/spanner_dbapi/partition_helper.py

Lines changed: 85 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,23 +13,104 @@
1313
# limitations under the License.
1414

1515
import base64
16+
import datetime
1617
import gzip
17-
import pickle
18+
import json
1819
from dataclasses import dataclass
1920
from typing import Any
2021

22+
from google.protobuf.json_format import MessageToDict, ParseDict
23+
from google.protobuf.message import Message
24+
2125
from google.cloud.spanner_v1 import BatchTransactionId
26+
from google.cloud.spanner_v1.types import ExecuteSqlRequest, DirectedReadOptions
27+
28+
_PROTO_CLASS_MAP = {
29+
"QueryOptions": ExecuteSqlRequest.QueryOptions,
30+
"DirectedReadOptions": DirectedReadOptions,
31+
}
32+
33+
34+
def _serialize_value(val: Any) -> Any:
35+
if isinstance(val, bytes):
36+
return {"__type__": "bytes", "value": base64.b64encode(val).decode("utf-8")}
37+
elif isinstance(val, datetime.datetime):
38+
return {"__type__": "datetime", "value": val.isoformat()}
39+
elif hasattr(val, "_pb"):
40+
return {
41+
"__type__": "protobuf",
42+
"class": val.__class__.__name__,
43+
"value": MessageToDict(val._pb, preserving_proto_field_name=True),
44+
}
45+
elif isinstance(val, Message):
46+
return {
47+
"__type__": "protobuf",
48+
"class": val.__class__.__name__,
49+
"value": MessageToDict(val, preserving_proto_field_name=True),
50+
}
51+
elif isinstance(val, dict):
52+
return {k: _serialize_value(v) for k, v in val.items()}
53+
elif isinstance(val, list):
54+
return [_serialize_value(v) for v in val]
55+
elif isinstance(val, tuple):
56+
return {"__type__": "tuple", "value": [_serialize_value(v) for v in val]}
57+
return val
58+
59+
60+
def _deserialize_value(val: Any) -> Any:
61+
if isinstance(val, dict):
62+
if "__type__" in val:
63+
t = val["__type__"]
64+
if t == "bytes":
65+
return base64.b64decode(val["value"])
66+
elif t == "datetime":
67+
dt_str = val["value"]
68+
if dt_str.endswith("Z"):
69+
dt_str = dt_str[:-1] + "+00:00"
70+
return datetime.datetime.fromisoformat(dt_str)
71+
elif t == "tuple":
72+
return tuple(_deserialize_value(x) for x in val["value"])
73+
elif t == "protobuf":
74+
cls_name = val.get("class")
75+
dict_val = val["value"]
76+
if cls_name in _PROTO_CLASS_MAP:
77+
cls = _PROTO_CLASS_MAP[cls_name]
78+
msg = cls()._pb
79+
ParseDict(dict_val, msg)
80+
return cls(msg)
81+
return _deserialize_value(dict_val)
82+
return {k: _deserialize_value(v) for k, v in val.items()}
83+
elif isinstance(val, list):
84+
return [_deserialize_value(v) for v in val]
85+
return val
2286

2387

2488
def decode_from_string(encoded_partition_id):
2589
gzip_bytes = base64.b64decode(bytes(encoded_partition_id, "utf-8"))
2690
partition_id_bytes = gzip.decompress(gzip_bytes)
27-
return pickle.loads(partition_id_bytes)
91+
92+
data = json.loads(partition_id_bytes.decode("utf-8"))
93+
btid_data = data["batch_transaction_id"]
94+
btid = BatchTransactionId(
95+
transaction_id=_deserialize_value(btid_data["transaction_id"]),
96+
session_id=btid_data["session_id"],
97+
read_timestamp=_deserialize_value(btid_data["read_timestamp"]),
98+
)
99+
partition_result = _deserialize_value(data["partition_result"])
100+
return PartitionId(btid, partition_result)
28101

29102

30103
def encode_to_string(batch_transaction_id, partition_result):
31-
partition_id = PartitionId(batch_transaction_id, partition_result)
32-
partition_id_bytes = pickle.dumps(partition_id)
104+
data = {
105+
"batch_transaction_id": {
106+
"transaction_id": _serialize_value(batch_transaction_id.transaction_id),
107+
"session_id": batch_transaction_id.session_id,
108+
"read_timestamp": _serialize_value(batch_transaction_id.read_timestamp),
109+
},
110+
"partition_result": _serialize_value(partition_result),
111+
}
112+
113+
partition_id_bytes = json.dumps(data).encode("utf-8")
33114
gzip_bytes = gzip.compress(partition_id_bytes)
34115
return str(base64.b64encode(gzip_bytes), "utf-8")
35116

Lines changed: 124 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,124 @@
1+
# Copyright 2024 Google LLC All rights reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import base64
16+
import datetime
17+
import gzip
18+
import json
19+
import unittest
20+
21+
from google.cloud.spanner_dbapi import partition_helper
22+
from google.cloud.spanner_v1 import BatchTransactionId
23+
from google.cloud.spanner_v1.types import ExecuteSqlRequest
24+
25+
26+
class TestPartitionHelper(unittest.TestCase):
27+
def test_encode_and_decode_success_query(self):
28+
btid = BatchTransactionId(
29+
transaction_id=b"test-txn-123",
30+
session_id="session-xyz",
31+
read_timestamp=datetime.datetime(
32+
2024, 5, 10, 12, 34, 56, tzinfo=datetime.timezone.utc
33+
),
34+
)
35+
36+
query_options = ExecuteSqlRequest.QueryOptions(
37+
optimizer_version="2",
38+
optimizer_statistics_package="package-abc",
39+
)
40+
41+
partition_result = {
42+
"partition": b"partition-token-456",
43+
"query": {
44+
"sql": "SELECT * FROM users WHERE age > %s",
45+
"params": {"age": 21},
46+
"query_options": query_options,
47+
},
48+
}
49+
50+
encoded = partition_helper.encode_to_string(btid, partition_result)
51+
self.assertIsInstance(encoded, str)
52+
53+
decoded = partition_helper.decode_from_string(encoded)
54+
self.assertIsInstance(decoded, partition_helper.PartitionId)
55+
56+
# Verify BatchTransactionId
57+
self.assertEqual(
58+
decoded.batch_transaction_id.transaction_id, btid.transaction_id
59+
)
60+
self.assertEqual(decoded.batch_transaction_id.session_id, btid.session_id)
61+
self.assertEqual(
62+
decoded.batch_transaction_id.read_timestamp, btid.read_timestamp
63+
)
64+
65+
# Verify partition result
66+
self.assertEqual(decoded.partition_result["partition"], b"partition-token-456")
67+
self.assertEqual(
68+
decoded.partition_result["query"]["sql"],
69+
"SELECT * FROM users WHERE age > %s",
70+
)
71+
self.assertEqual(decoded.partition_result["query"]["params"], {"age": 21})
72+
73+
# Verify query options (restored to object)
74+
opts_obj = decoded.partition_result["query"]["query_options"]
75+
self.assertEqual(opts_obj.optimizer_version, "2")
76+
self.assertEqual(opts_obj.optimizer_statistics_package, "package-abc")
77+
78+
def test_encode_and_decode_success_read(self):
79+
btid = BatchTransactionId(
80+
transaction_id=b"test-txn-456",
81+
session_id="session-abc",
82+
read_timestamp=None,
83+
)
84+
85+
partition_result = {
86+
"partition": b"partition-token-789",
87+
"read": {
88+
"table": "users",
89+
"columns": ["name", "age"],
90+
"keyset": {"keys": [[1], [2]]},
91+
},
92+
}
93+
94+
encoded = partition_helper.encode_to_string(btid, partition_result)
95+
decoded = partition_helper.decode_from_string(encoded)
96+
97+
self.assertEqual(
98+
decoded.batch_transaction_id.transaction_id, btid.transaction_id
99+
)
100+
self.assertEqual(decoded.batch_transaction_id.session_id, btid.session_id)
101+
self.assertIsNone(decoded.batch_transaction_id.read_timestamp)
102+
103+
self.assertEqual(decoded.partition_result["partition"], b"partition-token-789")
104+
self.assertEqual(decoded.partition_result["read"]["table"], "users")
105+
self.assertEqual(decoded.partition_result["read"]["columns"], ["name", "age"])
106+
self.assertEqual(
107+
decoded.partition_result["read"]["keyset"], {"keys": [[1], [2]]}
108+
)
109+
110+
def test_insecure_deserialization_failure(self):
111+
# Malicious payload that attempts to execute pickle.loads under old code
112+
# (Here, we'll just pass invalid JSON wrapped in gzip + base64, or a pickle payload,
113+
# and make sure it does NOT get deserialized or execute anything, but raises an error gracefully)
114+
115+
# A valid pickle payload for some simple object, base64 encoded and compressed
116+
import pickle
117+
118+
pickle_bytes = pickle.dumps({"test": "payload"})
119+
gzip_bytes = gzip.compress(pickle_bytes)
120+
encoded_pickle = base64.b64encode(gzip_bytes).decode("utf-8")
121+
122+
# Since we now use json.loads, a pickle payload will fail to decode as UTF-8 / JSON
123+
with self.assertRaises((json.JSONDecodeError, UnicodeDecodeError)):
124+
partition_helper.decode_from_string(encoded_pickle)

0 commit comments

Comments
 (0)