diff --git a/src/langchain_google_spanner/graph_store.py b/src/langchain_google_spanner/graph_store.py index a4d857e..b8cefc0 100644 --- a/src/langchain_google_spanner/graph_store.py +++ b/src/langchain_google_spanner/graph_store.py @@ -188,11 +188,14 @@ class ElementSchema(object): NODE_KEY_COLUMN_NAME: str = "id" TARGET_NODE_KEY_COLUMN_NAME: str = "target_id" + + # Reserved column names when `use_flexible_schema` is true. + # Properties are stored in a JSON column named `properties`; + # Edge types are stored in a string column named `label`. DYNAMIC_PROPERTY_COLUMN_NAME: str = "properties" DYNAMIC_LABEL_COLUMN_NAME: str = "label" name: str - original_name: str kind: str key_columns: List[str] base_table_name: str @@ -220,8 +223,7 @@ def make_node_schema( node.properties = CaseInsensitiveDict({prop: prop for prop in node.types}) node.labels = [node_label] node.base_table_name = "%s_%s" % (graph_name, node_label) - node.original_name = node_type - node.name = node.base_table_name + node.name = node_type node.kind = NODE_KIND node.key_columns = [ElementSchema.NODE_KEY_COLUMN_NAME] return node @@ -243,15 +245,18 @@ def make_edge_schema( edge.labels = [edge_label] edge.base_table_name = "%s_%s" % (graph_schema.graph_name, edge_label) edge.key_columns = key_columns - edge.original_name = edge_type - edge.name = edge.base_table_name + edge.name = edge_type edge.kind = EDGE_KIND - source_node_schema = graph_schema.get_node_schema(source_node_type) + source_node_schema = graph_schema.get_node_schema( + graph_schema.node_type_name(source_node_type) + ) if source_node_schema is None: raise ValueError("No source node schema `%s` found" % source_node_type) - target_node_schema = graph_schema.get_node_schema(target_node_type) + target_node_schema = graph_schema.get_node_schema( + graph_schema.node_type_name(target_node_type) + ) if target_node_schema is None: raise ValueError("No target node schema `%s` found" % target_node_type) @@ -346,7 +351,7 @@ def from_dynamic_nodes( ) ) return ElementSchema.make_node_schema( - name, NODE_KIND, graph_schema.graph_name, types + NODE_KIND, NODE_KIND, graph_schema.graph_name, types ) @staticmethod @@ -452,7 +457,7 @@ def from_dynamic_edges( ) ) return ElementSchema.make_edge_schema( - name, + EDGE_KIND, EDGE_KIND, graph_schema, [ @@ -567,14 +572,13 @@ def add_edges( @staticmethod def from_info_schema( element_schema: Dict[str, Any], - property_decls: List[Any], + decl_by_types: CaseInsensitiveDict, ) -> ElementSchema: """Builds ElementSchema from information schema represenation of an element. Args: element_schema: the information schema represenation of an element; - property_decls: the information schema represenation of property - declarations. + decl_by_types: type information of property declarations. Returns: ElementSchema @@ -584,7 +588,6 @@ def from_info_schema( """ element = ElementSchema() element.name = element_schema["name"] - element.original_name = element.name element.kind = element_schema["kind"] if element.kind not in [NODE_KIND, EDGE_KIND]: raise ValueError("Invalid element kind `{}`".format(element.kind)) @@ -592,18 +595,16 @@ def from_info_schema( element.key_columns = element_schema["keyColumns"] element.base_table_name = element_schema["baseTableName"] element.labels = element_schema["labelNames"] + element.properties = CaseInsensitiveDict( { prop_def["propertyDeclarationName"]: prop_def["valueExpressionSql"] for prop_def in element_schema.get("propertyDefinitions", []) + if prop_def["propertyDeclarationName"] in decl_by_types } ) element.types = CaseInsensitiveDict( - { - decl["name"]: TypeUtility.schema_str_to_spanner_type(decl["type"]) - for decl in property_decls - if decl["name"] in element.properties - } + {decl: decl_by_types[decl] for decl in element.properties.keys()} ) if element.kind == EDGE_KIND: @@ -636,7 +637,7 @@ def to_ddl(self, graph_schema: SpannerGraphSchema) -> str: to_identifiers = GraphDocumentUtility.to_identifiers def get_reference_node_table(name: str) -> str: - node_schema = graph_schema.node_tables.get(name, None) + node_schema = graph_schema.nodes.get(name, None) if node_schema is None: raise ValueError("No node schema `%s` found" % name) return node_schema.base_table_name @@ -708,13 +709,17 @@ def evolve(self, new_schema: ElementSchema) -> List[str]: ) ) - for k, v in new_schema.properties.items(): - if k in self.properties: - if self.properties[k].casefold() != v.casefold(): - raise ValueError( - "Property with name `{}` should have the same definition, got {}," - " expected {}".format(k, v, self.properties[k]) - ) + # Only validate property definition when they're the same definition, + # don't validate when two different definitions are based on the same + # underlying table. + if self.name == new_schema.name: + for k, v in new_schema.properties.items(): + if k in self.properties: + if self.properties[k].casefold() != v.casefold(): + raise ValueError( + "Property with name `{}` should have the same definition, got {}," + " expected {}".format(k, v, self.properties[k]) + ) for k, v in new_schema.types.items(): if k in self.types: @@ -845,16 +850,29 @@ def from_information_schema(self, info_schema: Dict[str, Any]) -> None: info_schema: the information schema represenation of a graph; """ property_decls = info_schema.get("propertyDeclarations", []) + decl_by_types = CaseInsensitiveDict( + { + decl["name"]: TypeUtility.schema_str_to_spanner_type(decl["type"]) + for decl in property_decls + if TypeUtility.schema_str_to_spanner_type(decl["type"]) is not None + } + ) for node in info_schema["nodeTables"]: - node_schema = ElementSchema.from_info_schema(node, property_decls) + node_schema = ElementSchema.from_info_schema(node, decl_by_types) self._update_node_schema(node_schema) self._update_labels_and_properties(node_schema) for edge in info_schema.get("edgeTables", []): - edge_schema = ElementSchema.from_info_schema(edge, property_decls) + edge_schema = ElementSchema.from_info_schema(edge, decl_by_types) self._update_edge_schema(edge_schema) self._update_labels_and_properties(edge_schema) + def node_type_name(self, name: str) -> str: + return NODE_KIND if self.use_flexible_schema else name + + def edge_type_name(self, name: str) -> str: + return EDGE_KIND if self.use_flexible_schema else name + def get_node_schema(self, name: str) -> Optional[ElementSchema]: """Gets the node schema by name. @@ -919,40 +937,54 @@ def __repr__(self) -> str: for k, v in self.properties.items() } ) + node_labels = {label for node in self.nodes.values() for label in node.labels} + edge_labels = {label for edge in self.edges.values() for label in edge.labels} + Triplet = Tuple[ElementSchema, ElementSchema, ElementSchema] + triplets_per_label: CaseInsensitiveDict[List[Triplet]] = CaseInsensitiveDict({}) + for edge in self.edges.values(): + for label in edge.labels: + source_node = self.get_node_schema(edge.source.node_name) + target_node = self.get_node_schema(edge.target.node_name) + if source_node is None: + raise ValueError(f"Source node {edge.source.node_name} not found") + if target_node is None: + raise ValueError(f"Tource node {edge.target.node_name} not found") + triplets_per_label.setdefault(label, []).append( + (source_node, edge, target_node) + ) return json.dumps( { "Name of graph": self.graph_name, - "Node properties per node type": { - node.name: [ + "Node properties per node label": { + label: [ { "property name": name, "property type": properties[name], } - for name in node.properties.keys() + for name in self.labels[label].prop_names ] - for node in self.nodes.values() + for label in node_labels }, - "Edge properties per edge type": { - edge.name: [ + "Edge properties per edge label": { + label: [ { "property name": name, "property type": properties[name], } - for name in edge.properties.keys() + for name in self.labels[label].prop_names ] - for edge in self.edges.values() - }, - "Node labels per node type": { - node.name: node.labels for node in self.nodes.values() + for label in edge_labels }, - "Edge labels per edge type": { - edge.name: edge.labels for edge in self.edges.values() - }, - "Edges": { - edge.name: "From {} nodes to {} nodes".format( - edge.source.node_name, edge.target.node_name - ) - for edge in self.edges.values() + "Possible edges per label": { + label: [ + "(:{}) -[:{}]-> (:{})".format( + source_node_label, label, target_node_label + ) + for (source, edge, target) in triplets + for source_node_label in source.labels + for target_node_label in target.labels + ] + for label, triplets in triplets_per_label.items() }, }, indent=2, @@ -1035,10 +1067,7 @@ def construct_element_table( ) ddl += "\nNODE TABLES(\n " ddl += ",\n ".join( - ( - construct_element_table(node, self.labels) - for node in self.node_tables.values() - ) + (construct_element_table(node, self.labels) for node in self.nodes.values()) ) ddl += "\n)" if len(self.edges) > 0: @@ -1046,7 +1075,7 @@ def construct_element_table( ddl += ",\n ".join( ( construct_element_table(edge, self.labels) - for edge in self.edge_tables.values() + for edge in self.edges.values() ) ) ddl += "\n)" @@ -1062,14 +1091,16 @@ def _update_node_schema(self, node_schema: ElementSchema) -> List[str]: List[str]: a list of DDL statements that requires to evolve the schema. """ - old_schema = self.node_tables.get(node_schema.name, None) - if old_schema is None: - ddls = [node_schema.to_ddl(self)] - self.node_tables[node_schema.name] = node_schema - else: + old_schema = self.nodes.get(node_schema.name, None) + if old_schema is not None: ddls = old_schema.evolve(node_schema) + elif node_schema.base_table_name in self.node_tables: + ddls = self.node_tables[node_schema.base_table_name].evolve(node_schema) + else: + ddls = [node_schema.to_ddl(self)] + self.node_tables[node_schema.base_table_name] = node_schema - self.nodes[node_schema.original_name] = old_schema or node_schema + self.nodes[node_schema.name] = old_schema or node_schema return ddls def _update_edge_schema(self, edge_schema: ElementSchema) -> List[str]: @@ -1081,15 +1112,16 @@ def _update_edge_schema(self, edge_schema: ElementSchema) -> List[str]: Returns: List[str]: a list of DDL statements that requires to evolve the schema. """ - if edge_schema.base_table_name not in self.edge_tables: + old_schema = self.edges.get(edge_schema.name, None) + if old_schema is not None: + ddls = old_schema.evolve(edge_schema) + elif edge_schema.base_table_name in self.edge_tables: + ddls = self.edge_tables[edge_schema.base_table_name].evolve(edge_schema) + else: ddls = [edge_schema.to_ddl(self)] self.edge_tables[edge_schema.base_table_name] = edge_schema - else: - ddls = self.edge_tables[edge_schema.base_table_name].evolve(edge_schema) - self.edges[edge_schema.original_name] = self.edge_tables[ - edge_schema.base_table_name - ] + self.edges[edge_schema.name] = old_schema or edge_schema return ddls def _update_labels_and_properties(self, element_schema: ElementSchema) -> None: @@ -1121,7 +1153,7 @@ def add_nodes( List[str]: a list of column names; List[List[Any]]: a list of rows. """ - node_schema = self.get_node_schema(name) + node_schema = self.get_node_schema(self.node_type_name(name)) if node_schema is None: raise ValueError("Unknown node schema: `%s`" % name) for v in node_schema.add_nodes(name, nodes): @@ -1142,8 +1174,9 @@ def add_edges( List[str]: a list of column names; List[List[Any]]: a list of rows. """ - edge_schema = self.get_edge_schema(name) + edge_schema = self.get_edge_schema(self.edge_type_name(name)) if edge_schema is None: + print(list(self.edges.keys())) raise ValueError("Unknown edge schema `%s`" % name) for v in edge_schema.add_edges(name, edges): yield v diff --git a/src/langchain_google_spanner/type_utils.py b/src/langchain_google_spanner/type_utils.py index 66a2a60..485e88e 100644 --- a/src/langchain_google_spanner/type_utils.py +++ b/src/langchain_google_spanner/type_utils.py @@ -16,7 +16,7 @@ import base64 import datetime -from typing import Any +from typing import Any, Optional from google.cloud.spanner_v1 import JsonObject, param_types @@ -67,14 +67,14 @@ def spanner_type_to_schema_str( raise ValueError("Unsupported type: %s" % t) @staticmethod - def schema_str_to_spanner_type(s: str) -> param_types.Type: + def schema_str_to_spanner_type(s: str) -> Optional[param_types.Type]: """Returns a Spanner type corresponding to the string representation from Spanner schema type. Parameters: - s: string representation of a Spanner schema type. Returns: - - Type[Any]: the corresponding Spanner type. + - Optional[param_types.Type]: the corresponding Spanner type. """ if s == "BOOL": return param_types.BOOL @@ -98,6 +98,10 @@ def schema_str_to_spanner_type(s: str) -> param_types.Type: return param_types.Array( TypeUtility.schema_str_to_spanner_type(s[len("ARRAY<") : -len(">")]) ) + if s == "TOKENLIST": + # There is no corresponding type for TOKENLIST in value type yet. + # Returns none to allow TOKENLIST in the schema. + return None raise ValueError("Unsupported type: %s" % s) @staticmethod diff --git a/tests/integration/test_spanner_graph_store.py b/tests/integration/test_spanner_graph_store.py index e9013f4..271dba3 100644 --- a/tests/integration/test_spanner_graph_store.py +++ b/tests/integration/test_spanner_graph_store.py @@ -14,6 +14,7 @@ import base64 import datetime +import json import os import random import string @@ -399,7 +400,6 @@ def test_spanner_graph_avoid_unnecessary_overwrite(self, use_flexible_schema): finally: print("Clean up graph with name `{}`".format(graph_name)) graph.cleanup() - print("Actual results:", results) @pytest.mark.parametrize( "graph_name, raises_exception", @@ -433,3 +433,93 @@ def test_spanner_graph_invalid_graph_name(self, graph_name, raises_exception): static_node_properties=["a", "b"], static_edge_properties=["a", "b"], ) + + @pytest.mark.parametrize("use_flexible_schema", [False, True]) + def test_spanner_graph_with_existing_graph(self, use_flexible_schema): + suffix = random_string(num_char=5, exclude_whitespaces=True) + graph_name = "test_graph{}".format(suffix) + node_table_name = "{}_node".format(graph_name) + edge_table_name = "{}_edge".format(graph_name) + graph = SpannerGraphStore( + instance_id, + google_database, + graph_name, + client=Client(project=project_id), + use_flexible_schema=use_flexible_schema, + ) + graph.refresh_schema() + try: + graph.impl.apply_ddls( + [ + f""" + CREATE TABLE IF NOT EXISTS {node_table_name} ( + id INT64 NOT NULL, + str STRING(MAX), + token TOKENLIST AS (TOKENIZE_FULLTEXT(str)) HIDDEN, + ) PRIMARY KEY (id) + """, + f""" + CREATE TABLE IF NOT EXISTS {edge_table_name} ( + id INT64 NOT NULL, + target_id INT64 NOT NULL, + ) PRIMARY KEY (id, target_id) + """, + f""" + CREATE PROPERTY GRAPH IF NOT EXISTS {graph_name} + NODE TABLES ( + {node_table_name} AS NodeA + LABEL Node + LABEL NodeA PROPERTIES(id, id AS node_a_id), + {node_table_name} AS NodeB + LABEL Node + LABEL NodeB PROPERTIES(id, id AS node_b_id) + ) + EDGE TABLES ( + {edge_table_name} AS EdgeAB + SOURCE KEY(id) REFERENCES NodeA + DESTINATION KEY(target_id) REFERENCES NodeB + LABEL Edge PROPERTIES(id AS source_id, target_id AS dest_id) + LABEL EdgeAB PROPERTIES(id AS node_a_id, target_id AS node_b_id), + {edge_table_name} AS EdgeBA + SOURCE KEY(id) REFERENCES NodeB + DESTINATION KEY(target_id) REFERENCES NodeA + LABEL Edge PROPERTIES(id AS source_id, target_id AS dest_id) + LABEL EdgeBA PROPERTIES(target_id AS node_a_id, id AS node_b_id), + ) + """, + ] + ) + graph.refresh_schema() + schema = json.loads(graph.get_schema) + edgeab = graph.schema.get_edge_schema("EdgeAB") + edgeba = graph.schema.get_edge_schema("EdgeBA") + assert (edgeab.source.node_name, edgeab.target.node_name) == ( + "NodeA", + "NodeB", + ) + assert (edgeba.source.node_name, edgeba.target.node_name) == ( + "NodeB", + "NodeA", + ) + # TOKENLIST-typed properties are ignored. + assert len(schema["Node properties per node label"]["Node"]) == 4, schema[ + "Node properties per node label" + ]["Node"] + assert len(schema["Node properties per node label"]["NodeA"]) == 3, schema[ + "Node properties per node label" + ]["NodeA"] + assert len(schema["Node properties per node label"]["NodeB"]) == 3, schema[ + "Node properties per node label" + ]["NodB"] + assert len(schema["Possible edges per label"]["EdgeAB"]) == 4, schema[ + "Possible edges per label" + ]["EdgeAB"] + assert len(schema["Possible edges per label"]["EdgeBA"]) == 4, schema[ + "Possible edges per label" + ]["EdgeBA"] + assert len(schema["Possible edges per label"]["Edge"]) == 8, schema[ + "Possible edges per label" + ]["Edge"] + finally: + print("Clean up graph with name `{}`".format(graph_name)) + graph.cleanup()