1515from __future__ import annotations
1616
1717import json
18+ import logging
1819import re
1920import string
2021from abc import ABC , abstractmethod
3536EDGE_KIND = "EDGE"
3637USER_AGENT_GRAPH_STORE = "langchain-google-spanner-python:graphstore/" + __version__
3738
39+ logger = logging .getLogger (__name__ )
40+
3841
3942class NodeWrapper (object ):
4043 """Wrapper around Node to support set operations using node id"""
@@ -763,6 +766,69 @@ def __init__(self, node_name: str, node_keys: List[str], edge_keys: List[str]):
763766 self .edge_keys = edge_keys
764767
765768
class JsonSchema(object):
    """Samples the key/type layout of JSON-typed graph properties.

    For each requested JSON property, a GQL query is run against the graph.
    The query picks one element (``LIMIT 1``) whose property is not NULL,
    lists the keys reported by ``JSON_KEYS(j, 1)`` and returns each key with
    the ``json_type`` of its value. The result is a sample taken from stored
    data, not a declared schema.
    """

    # Placeholders: graph_id, label_name, property_name. All three are
    # interpolated into backtick-quoted identifiers without escaping.
    # NOTE(review): assumes the names come from the database's own schema and
    # never contain a backtick -- confirm against callers.
    NODE_JSON_PROPERTY_QUERY_TEMPLATE = """
    GRAPH `{graph_id}`
    MATCH (n:`{label_name}`)
    WHERE n.`{property_name}` IS NOT NULL
    LET j = n.`{property_name}`
    LIMIT 1
    LET keys = JSON_KEYS(j, 1)
    FOR key IN keys
    LET v = j[key]
    LET type = json_type(v)
    RETURN key, type
    """

    # Same query as the node template, matching an edge pattern instead.
    EDGE_JSON_PROPERTY_QUERY_TEMPLATE = """
    GRAPH `{graph_id}`
    MATCH -[n:`{label_name}`]->
    WHERE n.`{property_name}` IS NOT NULL
    LET j = n.`{property_name}`
    LIMIT 1
    LET keys = JSON_KEYS(j, 1)
    FOR key IN keys
    LET v = j[key]
    LET type = json_type(v)
    RETURN key, type
    """

    def __init__(self, graph_name: str, impl: "SpannerInterface"):
        """Initializes the JSON schema sampler.

        Args:
            graph_name: name of the property graph to query.
            impl: object whose ``query`` method executes the sampling queries.
        """
        self._graph_name = graph_name
        self._impl = impl

    def get_node_json_property_schema(
        self, node_label: str, property_names: List[str]
    ) -> "CaseInsensitiveDict":
        """Samples the JSON layout of the given properties of a node label.

        Args:
            node_label: node label to match.
            property_names: names of the JSON properties to sample.

        Returns:
            A case-insensitive mapping from property name to the list of rows
            returned by the sampling query for that property.
        """
        return self._get_label_json_property_schema(
            node_label, property_names, self.NODE_JSON_PROPERTY_QUERY_TEMPLATE
        )

    def get_edge_json_property_schema(
        self, edge_label: str, property_names: List[str]
    ) -> "CaseInsensitiveDict":
        """Samples the JSON layout of the given properties of an edge label.

        Args:
            edge_label: edge label to match.
            property_names: names of the JSON properties to sample.

        Returns:
            A case-insensitive mapping from property name to the list of rows
            returned by the sampling query for that property.
        """
        return self._get_label_json_property_schema(
            edge_label, property_names, self.EDGE_JSON_PROPERTY_QUERY_TEMPLATE
        )

    def _get_label_json_property_schema(
        self, label: str, property_names: List[str], query_template: str
    ) -> "CaseInsensitiveDict":
        """Runs one sampling query per property and collects the rows.

        Args:
            label: label substituted for ``{label_name}`` in the template.
            property_names: properties to sample; one query is issued for
              each entry.
            query_template: one of the ``*_JSON_PROPERTY_QUERY_TEMPLATE``
              strings.

        Returns:
            A case-insensitive mapping from property name to the rows the
            query produced (empty when no element has a non-null value).
        """
        # Nothing to sample: skip the database round-trips entirely.
        if not property_names:
            return CaseInsensitiveDict({})

        rows_by_property = {}
        for pname in property_names:
            query = query_template.format(
                graph_id=self._graph_name,
                label_name=label,
                property_name=pname,
            )
            rows_by_property[pname] = list(self._impl.query(query))
        return CaseInsensitiveDict(rows_by_property)
830+
831+
766832class SpannerGraphSchema (object ):
767833 """Schema representation of a property graph."""
768834
@@ -778,6 +844,7 @@ def __init__(
778844 use_flexible_schema : bool ,
779845 static_node_properties : List [str ] = [],
780846 static_edge_properties : List [str ] = [],
847+ json_schema : Optional [JsonSchema ] = None ,
781848 ):
782849 """Initializes the graph schema.
783850
@@ -805,9 +872,16 @@ def __init__(
805872 self .edge_tables : CaseInsensitiveDict [ElementSchema ] = CaseInsensitiveDict ({})
806873 self .labels : CaseInsensitiveDict [Label ] = CaseInsensitiveDict ({})
807874 self .properties : CaseInsensitiveDict [param_types .Type ] = CaseInsensitiveDict ({})
875+ self .node_json_property_schema : CaseInsensitiveDict [Dict ] = CaseInsensitiveDict (
876+ {}
877+ )
878+ self .edge_json_property_schema : CaseInsensitiveDict [Dict ] = CaseInsensitiveDict (
879+ {}
880+ )
808881 self .use_flexible_schema = use_flexible_schema
809882 self .static_node_properties = set (static_node_properties )
810883 self .static_edge_properties = set (static_edge_properties )
884+ self .json_schema = json_schema
811885
812886 def evolve (self , graph_documents : List [GraphDocument ]) -> List [str ]:
813887 """Evolves current schema into a schema representing the input documents.
@@ -861,11 +935,13 @@ def from_information_schema(self, info_schema: Dict[str, Any]) -> None:
861935 node_schema = ElementSchema .from_info_schema (node , decl_by_types )
862936 self ._update_node_schema (node_schema )
863937 self ._update_labels_and_properties (node_schema )
938+ self ._update_json_property_schema (node_schema )
864939
865940 for edge in info_schema .get ("edgeTables" , []):
866941 edge_schema = ElementSchema .from_info_schema (edge , decl_by_types )
867942 self ._update_edge_schema (edge_schema )
868943 self ._update_labels_and_properties (edge_schema )
944+ self ._update_json_property_schema (edge_schema )
869945
870946 def node_type_name (self , name : str ) -> str :
871947 return NODE_KIND if self .use_flexible_schema else name
@@ -952,26 +1028,36 @@ def __repr__(self) -> str:
9521028 triplets_per_label .setdefault (label , []).append (
9531029 (source_node , edge , target_node )
9541030 )
1031+
1032+ def repr_property (lname , pname , ptype , json_fields ):
1033+ if not json_fields :
1034+ return {"name" : pname , "type" : ptype }
1035+ return {"name" : pname , "type" : ptype , "json_fields" : json_fields }
1036+
9551037 return json .dumps (
9561038 {
9571039 "Name of graph" : self .graph_name ,
9581040 "Node properties per node label" : {
9591041 label : [
960- {
961- "name" : name ,
962- "type" : properties [name ],
963- }
964- for name in sorted (self .labels [label ].prop_names )
1042+ repr_property (
1043+ label ,
1044+ pname ,
1045+ properties [pname ],
1046+ self .node_json_property_schema .get (label , {}).get (pname ),
1047+ )
1048+ for pname in sorted (self .labels [label ].prop_names )
9651049 ]
9661050 for label in sorted (node_labels )
9671051 },
9681052 "Edge properties per edge label" : {
9691053 label : [
970- {
971- "name" : name ,
972- "type" : properties [name ],
973- }
974- for name in sorted (self .labels [label ].prop_names )
1054+ repr_property (
1055+ label ,
1056+ pname ,
1057+ properties [pname ],
1058+ self .edge_json_property_schema .get (label , {}).get (pname ),
1059+ )
1060+ for pname in sorted (self .labels [label ].prop_names )
9751061 ]
9761062 for label in sorted (edge_labels )
9771063 },
@@ -1124,6 +1210,37 @@ def _update_edge_schema(self, edge_schema: ElementSchema) -> List[str]:
11241210 self .edges [edge_schema .name ] = old_schema or edge_schema
11251211 return ddls
11261212
def _update_json_property_schema(self, element_schema: "ElementSchema") -> None:
    """Samples and records the JSON property layout for an element schema.

    Does nothing when no ``json_schema`` sampler is configured or when the
    element has no labels. Otherwise the element's JSON-typed properties are
    sampled using its first label, and the resulting mapping is stored under
    every label of the element, in ``node_json_property_schema`` for nodes
    and ``edge_json_property_schema`` for anything else.

    Args:
        element_schema: the node or edge schema to sample.
    """
    if self.json_schema is None or not element_schema.labels:
        return

    json_property_names = [
        pname
        for pname, ptype in element_schema.types.items()
        if ptype == param_types.JSON
    ]

    # Node and edge handling differ only in the sampler method called and in
    # the dictionary that receives the result.
    if element_schema.kind == NODE_KIND:
        fetch_schema = self.json_schema.get_node_json_property_schema
        target = self.node_json_property_schema
    else:
        fetch_schema = self.json_schema.get_edge_json_property_schema
        target = self.edge_json_property_schema

    # Only the first label is queried; every label of the element then
    # shares the same resulting mapping object.
    json_property_schema = fetch_schema(
        element_schema.labels[0], json_property_names
    )
    target.update({label: json_property_schema for label in element_schema.labels})
1243+
11271244 def _update_labels_and_properties (self , element_schema : ElementSchema ) -> None :
11281245 """Updates labels and properties based on an element schema.
11291246
@@ -1176,7 +1293,6 @@ def add_edges(
11761293 """
11771294 edge_schema = self .get_edge_schema (self .edge_type_name (name ))
11781295 if edge_schema is None :
1179- print (list (self .edges .keys ()))
11801296 raise ValueError ("Unknown edge schema `%s`" % name )
11811297 for v in edge_schema .add_edges (name , edges ):
11821298 yield v
@@ -1265,7 +1381,7 @@ def apply_ddls(self, ddls: List[str], options: Dict[str, Any] = {}) -> None:
12651381 return
12661382
12671383 op = self .database .update_ddl (ddl_statements = ddls )
1268- print ("Waiting for DDL operations to complete..." )
1384+ logger . info ("Waiting for DDL operations to complete..." )
12691385 return op .result (options .get ("timeout" , DEFAULT_DDL_TIMEOUT ))
12701386
12711387 def insert_or_update (
@@ -1291,6 +1407,7 @@ def __init__(
12911407 static_edge_properties : List [str ] = [],
12921408 impl : Optional [SpannerInterface ] = None ,
12931409 timeout : Optional [float ] = None ,
1410+ include_json_schema : bool = False ,
12941411 ):
12951412 """Initializes SpannerGraphStore.
12961413
@@ -1306,7 +1423,9 @@ def __init__(
13061423 static_edge_properties: in flexible schema, treat these edge
13071424 properties as static.
13081425 timeout (Optional[float]): The timeout for queries in seconds.
1426+ include_json_schema (bool): Whether to include the fields of JSON properties in the schema.
13091427 """
1428+ self .graph_name = graph_name
13101429 self .impl = impl or SpannerImpl (
13111430 instance_id ,
13121431 database_id ,
@@ -1318,6 +1437,9 @@ def __init__(
13181437 use_flexible_schema ,
13191438 static_node_properties ,
13201439 static_edge_properties ,
1440+ json_schema = (
1441+ JsonSchema (graph_name , self .impl ) if include_json_schema else None
1442+ ),
13211443 )
13221444
13231445 self .refresh_schema ()
@@ -1345,25 +1467,28 @@ def add_graph_documents(
13451467 ddls = self .schema .evolve (graph_documents )
13461468 if ddls :
13471469 self .impl .apply_ddls (ddls )
1348- self .refresh_schema ()
13491470 else :
1350- print ("No schema change required..." )
1471+ logger . info ("No schema change required..." )
13511472
13521473 nodes , edges = partition_graph_docs (graph_documents )
13531474 for name , elements in nodes .items ():
13541475 if len (elements ) == 0 :
13551476 continue
13561477 for table , columns , rows in self .schema .add_nodes (name , elements ):
1357- print ("Insert nodes of type `{}`..." .format (name ))
1478+ logger . info ("Insert nodes of type `{}`..." .format (name ))
13581479 self .impl .insert_or_update (table , columns , rows )
13591480
13601481 for name , elements in edges .items ():
13611482 if len (elements ) == 0 :
13621483 continue
13631484 for table , columns , rows in self .schema .add_edges (name , elements ):
1364- print ("Insert edges of type `{}`..." .format (name ))
1485+ logger . info ("Insert edges of type `{}`..." .format (name ))
13651486 self .impl .insert_or_update (table , columns , rows )
13661487
1488+ # Refresh schema after data insertion because json property is sampled
1489+ # over the actual data.
1490+ self .refresh_schema ()
1491+
13671492 def query (self , query : str , params : dict = {}) -> List [Dict [str , Any ]]:
13681493 """Query Spanner database.
13691494
@@ -1435,5 +1560,8 @@ def cleanup(self):
14351560 ]
14361561 )
14371562 self .schema = SpannerGraphSchema (
1438- self .schema .graph_name , self .schema .use_flexible_schema
1563+ self .schema .graph_name , self .schema .use_flexible_schema ,
1564+ self .schema .static_node_properties ,
1565+ self .schema .static_edge_properties ,
1566+ self .schema .json_schema
14391567 )
0 commit comments