Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
165 changes: 99 additions & 66 deletions src/langchain_google_spanner/graph_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,11 +188,14 @@ class ElementSchema(object):

NODE_KEY_COLUMN_NAME: str = "id"
TARGET_NODE_KEY_COLUMN_NAME: str = "target_id"

# Reserved column names when `use_flexible_schema` is true.
# Properties are stored in a JSON column named `properties`;
# Edge types are stored in a string column named `label`.
DYNAMIC_PROPERTY_COLUMN_NAME: str = "properties"
DYNAMIC_LABEL_COLUMN_NAME: str = "label"

Comment thread
mtyin marked this conversation as resolved.
name: str
original_name: str
kind: str
key_columns: List[str]
base_table_name: str
Expand Down Expand Up @@ -220,8 +223,7 @@ def make_node_schema(
node.properties = CaseInsensitiveDict({prop: prop for prop in node.types})
node.labels = [node_label]
node.base_table_name = "%s_%s" % (graph_name, node_label)
node.original_name = node_type
node.name = node.base_table_name
node.name = node_type
node.kind = NODE_KIND
node.key_columns = [ElementSchema.NODE_KEY_COLUMN_NAME]
return node
Expand All @@ -243,15 +245,18 @@ def make_edge_schema(
edge.labels = [edge_label]
edge.base_table_name = "%s_%s" % (graph_schema.graph_name, edge_label)
edge.key_columns = key_columns
edge.original_name = edge_type
edge.name = edge.base_table_name
edge.name = edge_type
edge.kind = EDGE_KIND

source_node_schema = graph_schema.get_node_schema(source_node_type)
source_node_schema = graph_schema.get_node_schema(
graph_schema.node_type_name(source_node_type)
)
if source_node_schema is None:
raise ValueError("No source node schema `%s` found" % source_node_type)

target_node_schema = graph_schema.get_node_schema(target_node_type)
target_node_schema = graph_schema.get_node_schema(
graph_schema.node_type_name(target_node_type)
)
if target_node_schema is None:
raise ValueError("No target node schema `%s` found" % target_node_type)

Expand Down Expand Up @@ -346,7 +351,7 @@ def from_dynamic_nodes(
)
)
return ElementSchema.make_node_schema(
name, NODE_KIND, graph_schema.graph_name, types
NODE_KIND, NODE_KIND, graph_schema.graph_name, types
)

@staticmethod
Expand Down Expand Up @@ -452,7 +457,7 @@ def from_dynamic_edges(
)
)
return ElementSchema.make_edge_schema(
name,
EDGE_KIND,
EDGE_KIND,
graph_schema,
[
Expand Down Expand Up @@ -567,14 +572,13 @@ def add_edges(
@staticmethod
def from_info_schema(
element_schema: Dict[str, Any],
property_decls: List[Any],
decl_by_types: CaseInsensitiveDict,
) -> ElementSchema:
"""Builds ElementSchema from information schema represenation of an element.

Args:
element_schema: the information schema representation of an element;
property_decls: the information schema representation of property
declarations.
decl_by_types: type information of property declarations.

Returns:
ElementSchema
Expand All @@ -584,26 +588,23 @@ def from_info_schema(
"""
element = ElementSchema()
element.name = element_schema["name"]
element.original_name = element.name
element.kind = element_schema["kind"]
if element.kind not in [NODE_KIND, EDGE_KIND]:
raise ValueError("Invalid element kind `{}`".format(element.kind))

element.key_columns = element_schema["keyColumns"]
element.base_table_name = element_schema["baseTableName"]
element.labels = element_schema["labelNames"]

element.properties = CaseInsensitiveDict(
{
prop_def["propertyDeclarationName"]: prop_def["valueExpressionSql"]
for prop_def in element_schema.get("propertyDefinitions", [])
if prop_def["propertyDeclarationName"] in decl_by_types
}
)
element.types = CaseInsensitiveDict(
{
decl["name"]: TypeUtility.schema_str_to_spanner_type(decl["type"])
for decl in property_decls
if decl["name"] in element.properties
}
{decl: decl_by_types[decl] for decl in element.properties.keys()}
)

if element.kind == EDGE_KIND:
Expand Down Expand Up @@ -636,7 +637,7 @@ def to_ddl(self, graph_schema: SpannerGraphSchema) -> str:
to_identifiers = GraphDocumentUtility.to_identifiers

def get_reference_node_table(name: str) -> str:
node_schema = graph_schema.node_tables.get(name, None)
node_schema = graph_schema.nodes.get(name, None)
if node_schema is None:
raise ValueError("No node schema `%s` found" % name)
return node_schema.base_table_name
Expand Down Expand Up @@ -708,13 +709,17 @@ def evolve(self, new_schema: ElementSchema) -> List[str]:
)
)

for k, v in new_schema.properties.items():
if k in self.properties:
if self.properties[k].casefold() != v.casefold():
raise ValueError(
"Property with name `{}` should have the same definition, got {},"
" expected {}".format(k, v, self.properties[k])
)
# Only validate property definitions when both schemas describe the same
# element definition; skip the check when two different definitions are
# merely based on the same underlying table.
if self.name == new_schema.name:
for k, v in new_schema.properties.items():
if k in self.properties:
if self.properties[k].casefold() != v.casefold():
raise ValueError(
"Property with name `{}` should have the same definition, got {},"
" expected {}".format(k, v, self.properties[k])
)
Comment thread
mtyin marked this conversation as resolved.

for k, v in new_schema.types.items():
if k in self.types:
Expand Down Expand Up @@ -845,16 +850,29 @@ def from_information_schema(self, info_schema: Dict[str, Any]) -> None:
info_schema: the information schema representation of a graph;
"""
property_decls = info_schema.get("propertyDeclarations", [])
decl_by_types = CaseInsensitiveDict(
{
decl["name"]: TypeUtility.schema_str_to_spanner_type(decl["type"])
for decl in property_decls
if TypeUtility.schema_str_to_spanner_type(decl["type"]) is not None
}
)
for node in info_schema["nodeTables"]:
node_schema = ElementSchema.from_info_schema(node, property_decls)
node_schema = ElementSchema.from_info_schema(node, decl_by_types)
self._update_node_schema(node_schema)
self._update_labels_and_properties(node_schema)

for edge in info_schema.get("edgeTables", []):
edge_schema = ElementSchema.from_info_schema(edge, property_decls)
edge_schema = ElementSchema.from_info_schema(edge, decl_by_types)
self._update_edge_schema(edge_schema)
self._update_labels_and_properties(edge_schema)

def node_type_name(self, name: str) -> str:
    """Resolves the node type name used to look up a node schema.

    Args:
      name: the caller-supplied node type name.

    Returns:
      str: `NODE_KIND` when `use_flexible_schema` is set on this schema,
      otherwise `name` unchanged.
    """
    if self.use_flexible_schema:
        return NODE_KIND
    return name

def edge_type_name(self, name: str) -> str:
    """Resolves the edge type name used to look up an edge schema.

    Args:
      name: the caller-supplied edge type name.

    Returns:
      str: `EDGE_KIND` when `use_flexible_schema` is set on this schema,
      otherwise `name` unchanged.
    """
    if self.use_flexible_schema:
        return EDGE_KIND
    return name

def get_node_schema(self, name: str) -> Optional[ElementSchema]:
"""Gets the node schema by name.

Expand Down Expand Up @@ -919,40 +937,54 @@ def __repr__(self) -> str:
for k, v in self.properties.items()
}
)
node_labels = {label for node in self.nodes.values() for label in node.labels}
edge_labels = {label for edge in self.edges.values() for label in edge.labels}
Triplet = Tuple[ElementSchema, ElementSchema, ElementSchema]
triplets_per_label: CaseInsensitiveDict[List[Triplet]] = CaseInsensitiveDict({})
for edge in self.edges.values():
for label in edge.labels:
source_node = self.get_node_schema(edge.source.node_name)
target_node = self.get_node_schema(edge.target.node_name)
if source_node is None:
raise ValueError(f"Source node {edge.source.node_name} not found")
if target_node is None:
raise ValueError(f"Tource node {edge.target.node_name} not found")
triplets_per_label.setdefault(label, []).append(
(source_node, edge, target_node)
)
return json.dumps(
{
"Name of graph": self.graph_name,
"Node properties per node type": {
node.name: [
"Node properties per node label": {
label: [
{
"property name": name,
"property type": properties[name],
}
for name in node.properties.keys()
for name in self.labels[label].prop_names
]
for node in self.nodes.values()
for label in node_labels
},
"Edge properties per edge type": {
edge.name: [
"Edge properties per edge label": {
label: [
{
"property name": name,
"property type": properties[name],
}
for name in edge.properties.keys()
for name in self.labels[label].prop_names
]
for edge in self.edges.values()
},
"Node labels per node type": {
node.name: node.labels for node in self.nodes.values()
for label in edge_labels
},
"Edge labels per edge type": {
edge.name: edge.labels for edge in self.edges.values()
},
"Edges": {
edge.name: "From {} nodes to {} nodes".format(
edge.source.node_name, edge.target.node_name
)
for edge in self.edges.values()
"Possible edges per label": {
label: [
"(:{}) -[:{}]-> (:{})".format(
source_node_label, label, target_node_label
)
for (source, edge, target) in triplets
for source_node_label in source.labels
for target_node_label in target.labels
]
for label, triplets in triplets_per_label.items()
},
},
indent=2,
Expand Down Expand Up @@ -1035,18 +1067,15 @@ def construct_element_table(
)
ddl += "\nNODE TABLES(\n "
ddl += ",\n ".join(
(
construct_element_table(node, self.labels)
for node in self.node_tables.values()
)
(construct_element_table(node, self.labels) for node in self.nodes.values())
)
ddl += "\n)"
if len(self.edges) > 0:
ddl += "\nEDGE TABLES(\n "
ddl += ",\n ".join(
(
construct_element_table(edge, self.labels)
for edge in self.edge_tables.values()
for edge in self.edges.values()
)
)
ddl += "\n)"
Expand All @@ -1062,14 +1091,16 @@ def _update_node_schema(self, node_schema: ElementSchema) -> List[str]:
List[str]: a list of DDL statements required to evolve the schema.
"""

old_schema = self.node_tables.get(node_schema.name, None)
if old_schema is None:
ddls = [node_schema.to_ddl(self)]
self.node_tables[node_schema.name] = node_schema
else:
old_schema = self.nodes.get(node_schema.name, None)
if old_schema is not None:
ddls = old_schema.evolve(node_schema)
elif node_schema.base_table_name in self.node_tables:
ddls = self.node_tables[node_schema.base_table_name].evolve(node_schema)
else:
ddls = [node_schema.to_ddl(self)]
self.node_tables[node_schema.base_table_name] = node_schema

self.nodes[node_schema.original_name] = old_schema or node_schema
self.nodes[node_schema.name] = old_schema or node_schema
return ddls

def _update_edge_schema(self, edge_schema: ElementSchema) -> List[str]:
Expand All @@ -1081,15 +1112,16 @@ def _update_edge_schema(self, edge_schema: ElementSchema) -> List[str]:
Returns:
List[str]: a list of DDL statements required to evolve the schema.
"""
if edge_schema.base_table_name not in self.edge_tables:
old_schema = self.edges.get(edge_schema.name, None)
if old_schema is not None:
ddls = old_schema.evolve(edge_schema)
elif edge_schema.base_table_name in self.edge_tables:
ddls = self.edge_tables[edge_schema.base_table_name].evolve(edge_schema)
else:
ddls = [edge_schema.to_ddl(self)]
self.edge_tables[edge_schema.base_table_name] = edge_schema
else:
ddls = self.edge_tables[edge_schema.base_table_name].evolve(edge_schema)

self.edges[edge_schema.original_name] = self.edge_tables[
edge_schema.base_table_name
]
self.edges[edge_schema.name] = old_schema or edge_schema
return ddls

def _update_labels_and_properties(self, element_schema: ElementSchema) -> None:
Expand Down Expand Up @@ -1121,7 +1153,7 @@ def add_nodes(
List[str]: a list of column names;
List[List[Any]]: a list of rows.
"""
node_schema = self.get_node_schema(name)
node_schema = self.get_node_schema(self.node_type_name(name))
if node_schema is None:
raise ValueError("Unknown node schema: `%s`" % name)
for v in node_schema.add_nodes(name, nodes):
Expand All @@ -1142,8 +1174,9 @@ def add_edges(
List[str]: a list of column names;
List[List[Any]]: a list of rows.
"""
edge_schema = self.get_edge_schema(name)
edge_schema = self.get_edge_schema(self.edge_type_name(name))
if edge_schema is None:
print(list(self.edges.keys()))
raise ValueError("Unknown edge schema `%s`" % name)
for v in edge_schema.add_edges(name, edges):
yield v
Expand Down
10 changes: 7 additions & 3 deletions src/langchain_google_spanner/type_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

import base64
import datetime
from typing import Any
from typing import Any, Optional

from google.cloud.spanner_v1 import JsonObject, param_types

Expand Down Expand Up @@ -67,14 +67,14 @@ def spanner_type_to_schema_str(
raise ValueError("Unsupported type: %s" % t)

@staticmethod
def schema_str_to_spanner_type(s: str) -> param_types.Type:
def schema_str_to_spanner_type(s: str) -> Optional[param_types.Type]:
"""Returns a Spanner type corresponding to the string representation from Spanner schema type.

Parameters:
- s: string representation of a Spanner schema type.

Returns:
- Type[Any]: the corresponding Spanner type.
- Optional[param_types.Type]: the corresponding Spanner type.
"""
if s == "BOOL":
return param_types.BOOL
Expand All @@ -98,6 +98,10 @@ def schema_str_to_spanner_type(s: str) -> param_types.Type:
return param_types.Array(
TypeUtility.schema_str_to_spanner_type(s[len("ARRAY<") : -len(">")])
)
if s == "TOKENLIST":
# There is no corresponding type for TOKENLIST in value type yet.
# Returns none to allow TOKENLIST in the schema.
return None
raise ValueError("Unsupported type: %s" % s)

@staticmethod
Expand Down
Loading