Skip to content

Commit 194f3e3

Browse files
smukilSailesh MukilSailesh Mukil Gangatharan
authored
GraphServer: Remove reliance on CloudSpanner's TypeCode (#48)
* Standardize return type for execute_query() functions - Also removes cloud-spanner specific fields from the return type. Specifically `StructType.Field` is removed from the return type. Removing this tightly coupled logic is required to allow new DB implementations. * Abstract away "database" implementations and remove strong coupling of DB implementation with APIs 1. Abstracts SpannerDatabase with clear APIs 2. Introduces CloudSpannerDatabase as an implementation of SpannerDatabase 3. Removes further tight coupling with the cloud spanner client by adding a SpannerFieldInfo dataclass to replace usage of StructType.Field * Introduce exec_env.py to maintain global state + minor bug fix 1. The global database_instances is moved to exec_env.py to avoid circular imports. 2. SpannerFiledInfo.typename populated with the correct name now 3. Remove all cloud spanner refs from database.py * GraphServer: Remove reliance on CloudSpanner's TypeCode The google.cloud.spanner.TypeCode was used to validate supported types for properties. However, the type information is always received as a string from the JS client. It is only internally converted to a TypeCode for verification. Comparing against the TypeCode does not make the validation any more robust than just confirming if the strings themselves are valid type strings. Removing reliance on TypeCode also makes GraphServer database agnostic. * SpannerQueryResult: Change 'error' to 'err' All consumers of SpannerQueryResult expect `err` instead of `error`. Without this fix, failures are silent. With this, they're displayed within the notebook. Thanks to @cqian23 for pointing this out --------- Co-authored-by: Sailesh Mukil <mukil.sailesh@gmail.com> Co-authored-by: Sailesh Mukil Gangatharan <saileshmukil@google.com>
1 parent 9821745 commit 194f3e3

10 files changed

Lines changed: 390 additions & 257 deletions

spanner_graphs/cloud_database.py

Lines changed: 153 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,153 @@
1+
# Copyright 2024 Google LLC
2+
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""
16+
This module contains the cloud-specific implementation for talking to a Spanner database.
17+
"""
18+
19+
from __future__ import annotations
20+
import json
21+
from typing import Any, Dict, List, Tuple
22+
23+
from google.cloud import spanner
24+
from google.cloud.spanner_v1 import JsonObject
25+
from google.api_core.client_options import ClientOptions
26+
from google.cloud.spanner_v1.types import StructType, Type, TypeCode
27+
import pydata_google_auth
28+
29+
from spanner_graphs.database import SpannerDatabase, MockSpannerDatabase, SpannerQueryResult, SpannerFieldInfo
30+
31+
def _get_default_credentials_with_project():
32+
return pydata_google_auth.default(
33+
scopes=["https://www.googleapis.com/auth/cloud-platform"], use_local_webserver=False)
34+
35+
def get_as_field_info_list(fields: List[StructType.Field]) -> List[SpannerFieldInfo]:
36+
"""Converts a list of StructType.Field to a list of SpannerFieldInfo."""
37+
return [SpannerFieldInfo(name=field.name, typename=TypeCode(field.type_.code).name) for field in fields]
38+
39+
class CloudSpannerDatabase(SpannerDatabase):
40+
"""Concrete implementation for Spanner database on the cloud."""
41+
def __init__(self, project_id: str, instance_id: str,
42+
database_id: str) -> None:
43+
credentials, _ = _get_default_credentials_with_project()
44+
self.client = spanner.Client(
45+
project=project_id, credentials=credentials, client_options=ClientOptions(quota_project_id=project_id))
46+
self.instance = self.client.instance(instance_id)
47+
self.database = self.instance.database(database_id)
48+
self.schema_json: Any | None = None
49+
50+
def __repr__(self) -> str:
51+
return (f"<CloudSpannerDatabase["
52+
f"project:{self.client.project_name},"
53+
f"instance:{self.instance.name},"
54+
f"db:{self.database.name}]>")
55+
56+
def _extract_graph_name(self, query: str) -> str:
57+
words = query.strip().split()
58+
if len(words) < 3:
59+
raise ValueError("invalid query: must contain at least (GRAPH, graph_name and query)")
60+
61+
if words[0].upper() != "GRAPH":
62+
raise ValueError("invalid query: GRAPH must be the first word")
63+
64+
return words[1]
65+
66+
def _get_schema_for_graph(self, graph_query: str) -> Any | None:
67+
try:
68+
graph_name = self._extract_graph_name(graph_query)
69+
except ValueError:
70+
return None
71+
72+
with self.database.snapshot() as snapshot:
73+
schema_query = """
74+
SELECT property_graph_name, property_graph_metadata_json
75+
FROM information_schema.property_graphs
76+
WHERE property_graph_name = @graph_name
77+
"""
78+
params = {"graph_name": graph_name}
79+
param_type = {"graph_name": spanner.param_types.STRING}
80+
81+
result = snapshot.execute_sql(schema_query, params=params, param_types=param_type)
82+
schema_rows = list(result)
83+
84+
if schema_rows:
85+
return schema_rows[0][1]
86+
else:
87+
return None
88+
89+
def execute_query(
90+
self,
91+
query: str,
92+
limit: int = None,
93+
is_test_query: bool = False,
94+
) -> SpannerQueryResult:
95+
"""
96+
This method executes the provided `query`
97+
98+
Args:
99+
query: The SQL query to execute against the database
100+
limit: An optional limit for the number of rows to return
101+
is_test_query: If true, skips schema fetching for graph queries.
102+
103+
Returns:
104+
A `SpannerQueryResult`
105+
"""
106+
self.schema_json = None
107+
if not is_test_query:
108+
self.schema_json = self._get_schema_for_graph(query)
109+
110+
with self.database.snapshot() as snapshot:
111+
params = None
112+
param_types = None
113+
if limit and limit > 0:
114+
params = dict(limit=limit)
115+
116+
try:
117+
results = snapshot.execute_sql(query, params=params, param_types=param_types)
118+
rows = list(results)
119+
except Exception as e:
120+
return SpannerQueryResult(
121+
data={},
122+
fields=[],
123+
rows=[],
124+
schema_json=self.schema_json,
125+
err=e
126+
)
127+
128+
fields: List[SpannerFieldInfo] = get_as_field_info_list(results.fields)
129+
data = {field.name: [] for field in fields}
130+
131+
if len(fields) == 0:
132+
return SpannerQueryResult(
133+
data=data,
134+
fields=fields,
135+
rows=rows,
136+
schema_json=self.schema_json,
137+
err=None
138+
)
139+
140+
for row_data in rows:
141+
for field, value in zip(fields, row_data):
142+
if isinstance(value, JsonObject):
143+
data[field.name].append(json.loads(value.serialize()))
144+
else:
145+
data[field.name].append(value)
146+
147+
return SpannerQueryResult(
148+
data=data,
149+
fields=fields,
150+
rows=rows,
151+
schema_json=self.schema_json,
152+
err=None
153+
)

spanner_graphs/conversion.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -21,12 +21,11 @@
2121
from typing import Any, List, Dict, Tuple
2222
import json
2323

24-
from google.cloud.spanner_v1.types import TypeCode, StructType
25-
24+
from spanner_graphs.database import SpannerFieldInfo
2625
from spanner_graphs.graph_entities import Node, Edge
2726
from spanner_graphs.schema_manager import SchemaManager
2827

29-
def get_nodes_edges(data: Dict[str, List[Any]], fields: List[StructType.Field], schema_json: dict = None) -> Tuple[List[Node], List[Edge]]:
28+
def get_nodes_edges(data: Dict[str, List[Any]], fields: List[SpannerFieldInfo], schema_json: dict = None) -> Tuple[List[Node], List[Edge]]:
3029
schema_manager = SchemaManager(schema_json)
3130
nodes: List[Node] = []
3231
edges: List[Edge] = []
@@ -37,15 +36,15 @@ def get_nodes_edges(data: Dict[str, List[Any]], fields: List[StructType.Field],
3736
for field in fields:
3837
column_name = field.name
3938
column_data = data[column_name]
40-
39+
4140
# Only process JSON and Array of JSON types
42-
if field.type_.code not in [TypeCode.JSON, TypeCode.ARRAY]:
41+
if field.typename not in ["JSON", "ARRAY"]:
4342
continue
4443

4544
# Process each value in the column
4645
for value in column_data:
4746
items_to_process = []
48-
47+
4948
# Handle both single JSON and arrays of JSON
5049
if isinstance(value, list):
5150
items_to_process.extend(value)
@@ -92,4 +91,4 @@ def get_nodes_edges(data: Dict[str, List[Any]], fields: List[StructType.Field],
9291
nodes.append(Node.make_intermediate(identifier))
9392
node_identifiers.add(identifier)
9493

95-
return nodes, edges
94+
return nodes, edges

0 commit comments

Comments
 (0)