Skip to content
This repository was archived by the owner on Mar 31, 2026. It is now read-only.

Commit 280e436

Browse files
add mockspanner tests
1 parent 6b3a5e0 commit 280e436

3 files changed

Lines changed: 109 additions & 36 deletions

File tree

google/cloud/spanner_v1/testing/mock_spanner.py

Lines changed: 21 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -35,11 +35,17 @@
3535
class MockSpanner:
3636
def __init__(self):
3737
self.results = {}
38+
self.execute_streaming_sql_results = {}
3839
self.errors = {}
3940

4041
def add_result(self, sql: str, result: result_set.ResultSet):
4142
self.results[sql.lower().strip()] = result
4243

44+
def add_execute_streaming_sql_results(self, sql: str,
45+
partial_result_sets: list[result_set.PartialResultSet]):
46+
self.execute_streaming_sql_results[
47+
sql.lower().strip()] = partial_result_sets
48+
4349
def get_result(self, sql: str) -> result_set.ResultSet:
4450
result = self.results.get(sql.lower().strip())
4551
if result is None:
@@ -55,9 +61,20 @@ def pop_error(self, context):
5561
if error:
5662
context.abort_with_status(error)
5763

64+
def get_execute_streaming_sql_results(self, sql: str,
65+
started_transaction: transaction.Transaction) -> list[
66+
result_set.PartialResultSet]:
67+
if self.execute_streaming_sql_results[sql.lower().strip()]:
68+
partials = self.execute_streaming_sql_results[sql.lower().strip()]
69+
else:
70+
partials = self.get_result_as_partial_result_sets(sql)
71+
if started_transaction:
72+
partials[0].metadata.transaction = started_transaction
73+
return partials
74+
5875
def get_result_as_partial_result_sets(
59-
self, sql: str, started_transaction: transaction.Transaction
60-
) -> [result_set.PartialResultSet]:
76+
self, sql: str
77+
) -> list[result_set.PartialResultSet]:
6178
result: result_set.ResultSet = self.get_result(sql)
6279
partials = []
6380
first = True
@@ -70,11 +87,10 @@ def get_result_as_partial_result_sets(
7087
partial = result_set.PartialResultSet()
7188
if first:
7289
partial.metadata = ResultSetMetadata(result.metadata)
90+
first = False
7391
partial.values.extend(row)
7492
partials.append(partial)
7593
partials[len(partials) - 1].stats = result.stats
76-
if started_transaction:
77-
partials[0].metadata.transaction = started_transaction
7894
return partials
7995

8096

@@ -149,7 +165,7 @@ def ExecuteStreamingSql(self, request, context):
149165
self._requests.append(request)
150166
self.mock_spanner.pop_error(context)
151167
started_transaction = self.__maybe_create_transaction(request)
152-
partials = self.mock_spanner.get_result_as_partial_result_sets(
168+
partials = self.mock_spanner.get_execute_streaming_sql_results(
153169
request.sql, started_transaction
154170
)
155171
for result in partials:

tests/mockserver_tests/mock_server_test_base.py

Lines changed: 45 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -14,27 +14,36 @@
1414

1515
import unittest
1616

17-
from google.cloud.spanner_dbapi.parsed_statement import AutocommitDmlMode
18-
from google.cloud.spanner_v1.testing.mock_database_admin import DatabaseAdminServicer
19-
from google.cloud.spanner_v1.testing.mock_spanner import (
20-
start_mock_server,
21-
SpannerServicer,
22-
)
23-
import google.cloud.spanner_v1.types.type as spanner_type
24-
import google.cloud.spanner_v1.types.result_set as result_set
17+
import grpc
2518
from google.api_core.client_options import ClientOptions
2619
from google.auth.credentials import AnonymousCredentials
27-
from google.cloud.spanner_v1 import Client, TypeCode, FixedSizePool
28-
from google.cloud.spanner_v1.database import Database
29-
from google.cloud.spanner_v1.instance import Instance
30-
import grpc
20+
from google.cloud.spanner_v1 import Type
21+
22+
from google.cloud.spanner_v1 import StructType
23+
from google.cloud.spanner_v1._helpers import _make_value_pb
24+
25+
from google.cloud.spanner_v1 import PartialResultSet
26+
from google.protobuf.duration_pb2 import Duration
3127
from google.rpc import code_pb2
3228
from google.rpc import status_pb2
3329
from google.rpc.error_details_pb2 import RetryInfo
34-
from google.protobuf.duration_pb2 import Duration
3530
from grpc_status._common import code_to_grpc_status_code
3631
from grpc_status.rpc_status import _Status
3732

33+
import google.cloud.spanner_v1.types.result_set as result_set
34+
import google.cloud.spanner_v1.types.type as spanner_type
35+
from google.cloud.spanner_dbapi.parsed_statement import AutocommitDmlMode
36+
from google.cloud.spanner_v1 import Client
37+
from google.cloud.spanner_v1 import FixedSizePool
38+
from google.cloud.spanner_v1 import ResultSetMetadata
39+
from google.cloud.spanner_v1 import TypeCode
40+
from google.cloud.spanner_v1.database import Database
41+
from google.cloud.spanner_v1.instance import Instance
42+
from google.cloud.spanner_v1.testing.mock_database_admin import \
43+
DatabaseAdminServicer
44+
from google.cloud.spanner_v1.testing.mock_spanner import SpannerServicer
45+
from google.cloud.spanner_v1.testing.mock_spanner import start_mock_server
46+
3847

3948
# Creates an aborted status with the smallest possible retry delay.
4049
def aborted_status() -> _Status:
@@ -57,6 +66,24 @@ def aborted_status() -> _Status:
5766
return status
5867

5968

69+
def _make_partial_result_sets(fields: list[tuple[str, TypeCode]],
70+
results: list[dict]) -> list[result_set.PartialResultSet]:
71+
partial_result_sets = []
72+
for result in results:
73+
partial_result_set = PartialResultSet()
74+
if len(partial_result_sets) == 0:
75+
# setting the metadata
76+
metadata = ResultSetMetadata(row_type=StructType(fields=[]))
77+
for field in fields:
78+
metadata.row_type.fields.append(
79+
StructType.Field(name=field[0], type_=Type(code=field[1])))
80+
partial_result_set.metadata = metadata
81+
for value in result["values"]:
82+
partial_result_set.values.append(_make_value_pb(value))
83+
partial_result_set.last = result.get('last') or False
84+
partial_result_sets.append(partial_result_set)
85+
return partial_result_sets
86+
6087
# Creates an UNAVAILABLE status with the smallest possible retry delay.
6188
def unavailable_status() -> _Status:
6289
error = status_pb2.Status(
@@ -101,6 +128,11 @@ def add_select1_result():
101128
add_single_result("select 1", "c", TypeCode.INT64, [("1",)])
102129

103130

131+
def add_execute_streaming_sql_results(sql: str,
132+
partial_result_sets: list[result_set.PartialResultSet]):
133+
MockServerTestBase.spanner_service.mock_spanner.add_execute_streaming_sql_results(
134+
sql, partial_result_sets)
135+
104136
def add_single_result(
105137
sql: str, column_name: str, type_code: spanner_type.TypeCode, row
106138
):

tests/mockserver_tests/test_basics.py

Lines changed: 43 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -15,28 +15,32 @@
1515
from google.cloud.spanner_admin_database_v1.types import spanner_database_admin
1616
from google.cloud.spanner_dbapi import Connection
1717
from google.cloud.spanner_dbapi.parsed_statement import AutocommitDmlMode
18-
from google.cloud.spanner_v1 import (
19-
BatchCreateSessionsRequest,
20-
ExecuteSqlRequest,
21-
BeginTransactionRequest,
22-
TransactionOptions,
23-
ExecuteBatchDmlRequest,
24-
TypeCode,
25-
)
26-
from google.cloud.spanner_v1.transaction import Transaction
18+
from google.cloud.spanner_v1 import BatchCreateSessionsRequest
19+
from google.cloud.spanner_v1 import BeginTransactionRequest
20+
from google.cloud.spanner_v1 import ExecuteBatchDmlRequest
21+
from google.cloud.spanner_v1 import ExecuteSqlRequest
22+
from google.cloud.spanner_v1 import TransactionOptions
23+
from google.cloud.spanner_v1 import TypeCode
2724
from google.cloud.spanner_v1.testing.mock_spanner import SpannerServicer
28-
29-
from tests.mockserver_tests.mock_server_test_base import (
30-
MockServerTestBase,
31-
add_select1_result,
32-
add_update_count,
33-
add_error,
34-
unavailable_status,
35-
add_single_result,
36-
)
25+
from google.cloud.spanner_v1.transaction import Transaction
26+
from tests.mockserver_tests.mock_server_test_base import MockServerTestBase
27+
from tests.mockserver_tests.mock_server_test_base import \
28+
_make_partial_result_sets
29+
from tests.mockserver_tests.mock_server_test_base import add_error
30+
from tests.mockserver_tests.mock_server_test_base import \
31+
add_execute_streaming_sql_results
32+
from tests.mockserver_tests.mock_server_test_base import add_select1_result
33+
from tests.mockserver_tests.mock_server_test_base import add_single_result
34+
from tests.mockserver_tests.mock_server_test_base import add_update_count
35+
from tests.mockserver_tests.mock_server_test_base import unavailable_status
3736

3837

3938
class TestBasics(MockServerTestBase):
39+
40+
def setUp(self):
41+
super().setUp()
42+
super().setup_class()
43+
4044
def test_select1(self):
4145
add_select1_result()
4246
with self.database.snapshot() as snapshot:
@@ -176,6 +180,27 @@ def test_last_statement_query(self):
176180
self.assertEqual(1, len(requests), msg=requests)
177181
self.assertTrue(requests[0].last_statement, requests[0])
178182

183+
def test_execute_streaming_sql_last_field(self):
184+
partial_result_sets = _make_partial_result_sets(
185+
[("ID", TypeCode.INT64), ("NAME", TypeCode.STRING)],
186+
[{"values": ["1", "ABC", "2", "DEF"]},
187+
{"values": ["3", "GHI"], "last": True}])
188+
189+
sql = "select * from my_table"
190+
add_execute_streaming_sql_results(sql, partial_result_sets)
191+
count = 1
192+
with self.database.snapshot() as snapshot:
193+
results = snapshot.execute_sql(sql)
194+
result_list = []
195+
for row in results:
196+
result_list.append(row)
197+
self.assertEqual(count, row[0])
198+
count += 1
199+
self.assertEqual(3, len(result_list))
200+
requests = self.spanner_service.requests
201+
self.assertEqual(2, len(requests), msg=requests)
202+
self.assertTrue(isinstance(requests[0], BatchCreateSessionsRequest))
203+
self.assertTrue(isinstance(requests[1], ExecuteSqlRequest))
179204

180205
def _execute_query(transaction: Transaction, sql: str):
181206
rows = transaction.execute_sql(sql, last_statement=True)

0 commit comments

Comments
 (0)