Skip to content

Commit fe6cb9a

Browse files
committed
Include json schema in the graph schema representation
Allow graph schema to include the json fields. This is useful for QA chain to handle queries that refers to json properties. For example, for node: Company(details = Json('market_cap', ...)) graph schema will include 'market_cap' as a json_fields of `details` in the schema representation, so that when a user ask for `get market capitalization of company`, QA chain can understand which subfield to refer to. Note: - json property schema is done via by inspecting the first non-null property. - other changes: refactor the tests, improve the logging
1 parent 6bb86ac commit fe6cb9a

2 files changed

Lines changed: 397 additions & 263 deletions

File tree

src/langchain_google_spanner/graph_store.py

Lines changed: 145 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from __future__ import annotations
1616

1717
import json
18+
import logging
1819
import re
1920
import string
2021
from abc import ABC, abstractmethod
@@ -35,6 +36,8 @@
3536
EDGE_KIND = "EDGE"
3637
USER_AGENT_GRAPH_STORE = "langchain-google-spanner-python:graphstore/" + __version__
3738

39+
logger = logging.getLogger(__name__)
40+
3841

3942
class NodeWrapper(object):
4043
"""Wrapper around Node to support set operations using node id"""
@@ -763,6 +766,69 @@ def __init__(self, node_name: str, node_keys: List[str], edge_keys: List[str]):
763766
self.edge_keys = edge_keys
764767

765768

769+
class JsonSchema(object):
770+
NODE_JSON_PROPERTY_QUERY_TEMPLATE = """
771+
GRAPH `{graph_id}`
772+
MATCH (n:`{label_name}`)
773+
WHERE n.`{property_name}` IS NOT NULL
774+
LET j = n.`{property_name}`
775+
LIMIT 1
776+
LET keys = JSON_KEYS(j, 1)
777+
FOR key IN keys
778+
LET v = j[key]
779+
LET type = json_type(v)
780+
RETURN key, type
781+
"""
782+
783+
EDGE_JSON_PROPERTY_QUERY_TEMPLATE = """
784+
GRAPH `{graph_id}`
785+
MATCH -[n:`{label_name}`]->
786+
WHERE n.`{property_name}` IS NOT NULL
787+
LET j = n.`{property_name}`
788+
LIMIT 1
789+
LET keys = JSON_KEYS(j, 1)
790+
FOR key IN keys
791+
LET v = j[key]
792+
let type = json_type(v)
793+
RETURN key, type
794+
"""
795+
796+
def __init__(self, graph_name: str, impl: SpannerInterface):
797+
self._graph_name = graph_name
798+
self._impl = impl
799+
800+
def get_node_json_property_schema(self, node_label: str, property_names: List[str]):
801+
return self._get_label_json_property_schema(
802+
node_label, property_names, self.NODE_JSON_PROPERTY_QUERY_TEMPLATE
803+
)
804+
805+
def get_edge_json_property_schema(self, edge_label: str, property_names: List[str]):
806+
return self._get_label_json_property_schema(
807+
edge_label, property_names, self.EDGE_JSON_PROPERTY_QUERY_TEMPLATE
808+
)
809+
810+
def _get_label_json_property_schema(
811+
self, label: str, property_names: List[str], query_template: str
812+
):
813+
if len(property_names) == 0:
814+
return CaseInsensitiveDict({})
815+
return CaseInsensitiveDict(
816+
{
817+
pname: [
818+
row
819+
for row in self._impl.query(
820+
query_template.format(
821+
graph_id=self._graph_name,
822+
label_name=label,
823+
property_name=pname,
824+
)
825+
)
826+
]
827+
for pname in property_names
828+
}
829+
)
830+
831+
766832
class SpannerGraphSchema(object):
767833
"""Schema representation of a property graph."""
768834

@@ -778,6 +844,7 @@ def __init__(
778844
use_flexible_schema: bool,
779845
static_node_properties: List[str] = [],
780846
static_edge_properties: List[str] = [],
847+
json_schema: Optional[JsonSchema] = None,
781848
):
782849
"""Initializes the graph schema.
783850
@@ -805,9 +872,16 @@ def __init__(
805872
self.edge_tables: CaseInsensitiveDict[ElementSchema] = CaseInsensitiveDict({})
806873
self.labels: CaseInsensitiveDict[Label] = CaseInsensitiveDict({})
807874
self.properties: CaseInsensitiveDict[param_types.Type] = CaseInsensitiveDict({})
875+
self.node_json_property_schema: CaseInsensitiveDict[Dict] = CaseInsensitiveDict(
876+
{}
877+
)
878+
self.edge_json_property_schema: CaseInsensitiveDict[Dict] = CaseInsensitiveDict(
879+
{}
880+
)
808881
self.use_flexible_schema = use_flexible_schema
809882
self.static_node_properties = set(static_node_properties)
810883
self.static_edge_properties = set(static_edge_properties)
884+
self.json_schema = json_schema
811885

812886
def evolve(self, graph_documents: List[GraphDocument]) -> List[str]:
813887
"""Evolves current schema into a schema representing the input documents.
@@ -861,11 +935,13 @@ def from_information_schema(self, info_schema: Dict[str, Any]) -> None:
861935
node_schema = ElementSchema.from_info_schema(node, decl_by_types)
862936
self._update_node_schema(node_schema)
863937
self._update_labels_and_properties(node_schema)
938+
self._update_json_property_schema(node_schema)
864939

865940
for edge in info_schema.get("edgeTables", []):
866941
edge_schema = ElementSchema.from_info_schema(edge, decl_by_types)
867942
self._update_edge_schema(edge_schema)
868943
self._update_labels_and_properties(edge_schema)
944+
self._update_json_property_schema(edge_schema)
869945

870946
def node_type_name(self, name: str) -> str:
871947
return NODE_KIND if self.use_flexible_schema else name
@@ -952,26 +1028,36 @@ def __repr__(self) -> str:
9521028
triplets_per_label.setdefault(label, []).append(
9531029
(source_node, edge, target_node)
9541030
)
1031+
1032+
def repr_property(lname, pname, ptype, json_fields):
1033+
if not json_fields:
1034+
return {"name": pname, "type": ptype}
1035+
return {"name": pname, "type": ptype, "json_fields": json_fields}
1036+
9551037
return json.dumps(
9561038
{
9571039
"Name of graph": self.graph_name,
9581040
"Node properties per node label": {
9591041
label: [
960-
{
961-
"name": name,
962-
"type": properties[name],
963-
}
964-
for name in sorted(self.labels[label].prop_names)
1042+
repr_property(
1043+
label,
1044+
pname,
1045+
properties[pname],
1046+
self.node_json_property_schema.get(label, {}).get(pname),
1047+
)
1048+
for pname in sorted(self.labels[label].prop_names)
9651049
]
9661050
for label in sorted(node_labels)
9671051
},
9681052
"Edge properties per edge label": {
9691053
label: [
970-
{
971-
"name": name,
972-
"type": properties[name],
973-
}
974-
for name in sorted(self.labels[label].prop_names)
1054+
repr_property(
1055+
label,
1056+
pname,
1057+
properties[pname],
1058+
self.edge_json_property_schema.get(label, {}).get(pname),
1059+
)
1060+
for pname in sorted(self.labels[label].prop_names)
9751061
]
9761062
for label in sorted(edge_labels)
9771063
},
@@ -1124,6 +1210,37 @@ def _update_edge_schema(self, edge_schema: ElementSchema) -> List[str]:
11241210
self.edges[edge_schema.name] = old_schema or edge_schema
11251211
return ddls
11261212

1213+
def _update_json_property_schema(self, element_schema: ElementSchema) -> None:
1214+
if self.json_schema is None:
1215+
return
1216+
if len(element_schema.labels) == 0:
1217+
return
1218+
lname = element_schema.labels[0]
1219+
if element_schema.kind == NODE_KIND:
1220+
json_property_schema = self.json_schema.get_node_json_property_schema(
1221+
lname,
1222+
[
1223+
pname
1224+
for pname, ptype in element_schema.types.items()
1225+
if ptype == param_types.JSON
1226+
],
1227+
)
1228+
self.node_json_property_schema.update(
1229+
{l: json_property_schema for l in element_schema.labels}
1230+
)
1231+
else:
1232+
json_property_schema = self.json_schema.get_edge_json_property_schema(
1233+
lname,
1234+
[
1235+
pname
1236+
for pname, ptype in element_schema.types.items()
1237+
if ptype == param_types.JSON
1238+
],
1239+
)
1240+
self.edge_json_property_schema.update(
1241+
{l: json_property_schema for l in element_schema.labels}
1242+
)
1243+
11271244
def _update_labels_and_properties(self, element_schema: ElementSchema) -> None:
11281245
"""Updates labels and properties based on an element schema.
11291246
@@ -1176,7 +1293,6 @@ def add_edges(
11761293
"""
11771294
edge_schema = self.get_edge_schema(self.edge_type_name(name))
11781295
if edge_schema is None:
1179-
print(list(self.edges.keys()))
11801296
raise ValueError("Unknown edge schema `%s`" % name)
11811297
for v in edge_schema.add_edges(name, edges):
11821298
yield v
@@ -1265,7 +1381,7 @@ def apply_ddls(self, ddls: List[str], options: Dict[str, Any] = {}) -> None:
12651381
return
12661382

12671383
op = self.database.update_ddl(ddl_statements=ddls)
1268-
print("Waiting for DDL operations to complete...")
1384+
logger.info("Waiting for DDL operations to complete...")
12691385
return op.result(options.get("timeout", DEFAULT_DDL_TIMEOUT))
12701386

12711387
def insert_or_update(
@@ -1291,6 +1407,7 @@ def __init__(
12911407
static_edge_properties: List[str] = [],
12921408
impl: Optional[SpannerInterface] = None,
12931409
timeout: Optional[float] = None,
1410+
include_json_schema: bool = False,
12941411
):
12951412
"""Initializes SpannerGraphStore.
12961413
@@ -1306,7 +1423,9 @@ def __init__(
13061423
static_edge_properties: in flexible schema, treat these edge
13071424
properties as static.
13081425
timeout (Optional[float]): The timeout for queries in seconds.
1426+
include_json_schema (Optional[bool]): Whether to include json fields in the schema.
13091427
"""
1428+
self.graph_name = graph_name
13101429
self.impl = impl or SpannerImpl(
13111430
instance_id,
13121431
database_id,
@@ -1318,6 +1437,9 @@ def __init__(
13181437
use_flexible_schema,
13191438
static_node_properties,
13201439
static_edge_properties,
1440+
json_schema=(
1441+
JsonSchema(graph_name, self.impl) if include_json_schema else None
1442+
),
13211443
)
13221444

13231445
self.refresh_schema()
@@ -1345,25 +1467,28 @@ def add_graph_documents(
13451467
ddls = self.schema.evolve(graph_documents)
13461468
if ddls:
13471469
self.impl.apply_ddls(ddls)
1348-
self.refresh_schema()
13491470
else:
1350-
print("No schema change required...")
1471+
logger.info("No schema change required...")
13511472

13521473
nodes, edges = partition_graph_docs(graph_documents)
13531474
for name, elements in nodes.items():
13541475
if len(elements) == 0:
13551476
continue
13561477
for table, columns, rows in self.schema.add_nodes(name, elements):
1357-
print("Insert nodes of type `{}`...".format(name))
1478+
logger.info("Insert nodes of type `{}`...".format(name))
13581479
self.impl.insert_or_update(table, columns, rows)
13591480

13601481
for name, elements in edges.items():
13611482
if len(elements) == 0:
13621483
continue
13631484
for table, columns, rows in self.schema.add_edges(name, elements):
1364-
print("Insert edges of type `{}`...".format(name))
1485+
logger.info("Insert edges of type `{}`...".format(name))
13651486
self.impl.insert_or_update(table, columns, rows)
13661487

1488+
# Refresh schema after data insertion because json property is sampled
1489+
# over the actual data.
1490+
self.refresh_schema()
1491+
13671492
def query(self, query: str, params: dict = {}) -> List[Dict[str, Any]]:
13681493
"""Query Spanner database.
13691494
@@ -1435,5 +1560,8 @@ def cleanup(self):
14351560
]
14361561
)
14371562
self.schema = SpannerGraphSchema(
1438-
self.schema.graph_name, self.schema.use_flexible_schema
1563+
self.schema.graph_name, self.schema.use_flexible_schema,
1564+
self.schema.static_node_properties,
1565+
self.schema.static_edge_properties,
1566+
self.schema.json_schema
14391567
)

0 commit comments

Comments
 (0)