@@ -671,7 +671,148 @@ def _estimate_json_size(df: pandas.DataFrame) -> int:
671671 return int (key_overhead + structural_overhead + total_val_len )
672672
673673
674- def _add_graph_widget (query_result : pandas .DataFrame , query_job : Any , args : Any ):
def _convert_schema(schema_json: str) -> str:
    """Translate BigQuery property-graph metadata into visualizer schema JSON.

    The input is parsed, every node and edge table is rewritten into the
    layout used by the visualization framework, and the labels and property
    declarations encountered along the way are collected into top-level
    lists.

    Args:
        schema_json: JSON text in the BigQuery property graph schema format.

    Returns:
        JSON text (indented by two spaces) in the visualization framework
        format. The graph name falls back to "SampleGraph" when the input
        carries no ``propertyGraphReference.propertyGraphId``.
    """
    source = json.loads(schema_json)
    graph_name = source.get("propertyGraphReference", {}).get(
        "propertyGraphId", "SampleGraph"
    )

    # Registries shared by all tables. Dict insertion order determines the
    # order of the emitted "labels" and "propertyDeclarations" lists.
    label_properties = {}  # label name -> set of property names
    property_types = {}  # property name -> type kind (last occurrence wins)

    def node_reference(reference):
        # Rename the endpoint description of an edge table.
        return {
            "nodeTableName": reference.get("nodeTable"),
            "edgeTableColumns": reference.get("edgeTableColumns"),
            "nodeTableColumns": reference.get("nodeTableColumns"),
        }

    def build_table_entry(table, kind):
        # Convert one node/edge table and record its labels and properties.
        table_name = table.get("name")
        base_table = table.get("dataSourceTable", {}).get("tableId")
        keys = table.get("keyColumns", [])

        labels = []
        definitions = []
        for label_block in table.get("labelAndProperties", []):
            label = label_block.get("label")
            labels.append(label)
            known_properties = label_properties.setdefault(label, set())

            for prop in label_block.get("properties", []):
                declared_name = prop.get("name")
                declared_type = prop.get("dataType", {}).get("typeKind")

                known_properties.add(declared_name)
                property_types[declared_name] = declared_type
                definitions.append(
                    {
                        "propertyDeclarationName": declared_name,
                        "valueExpressionSql": prop.get("expression"),
                    }
                )

        converted_table = {
            "name": table_name,
            "baseTableName": base_table,
            "kind": kind,
            "labelNames": labels,
            "keyColumns": keys,
            "propertyDefinitions": definitions,
        }
        if kind != "EDGE":
            return converted_table

        # Only edge tables carry source/destination endpoints.
        converted_table["sourceNodeTable"] = node_reference(
            table.get("sourceNodeReference", {})
        )
        converted_table["destinationNodeTable"] = node_reference(
            table.get("destinationNodeReference", {})
        )
        return converted_table

    # Node tables are processed before edge tables so the registries fill in
    # that order; both must be complete before the summaries are built.
    node_entries = [
        build_table_entry(node, "NODE") for node in source.get("nodeTables", [])
    ]
    edge_entries = [
        build_table_entry(edge, "EDGE") for edge in source.get("edgeTables", [])
    ]

    converted = {
        "catalog": "",
        "name": graph_name,
        "schema": "",
        "labels": [
            {"name": label, "propertyDeclarationNames": sorted(names)}
            for label, names in label_properties.items()
        ],
        "nodeTables": node_entries,
        "edgeTables": edge_entries,
        "propertyDeclarations": [
            {"name": declared_name, "type": declared_type}
            for declared_name, declared_type in property_types.items()
        ],
    }
    return json.dumps(converted, indent=2)
778+
779+
def _get_graph_name(query_text: str):
    """Returns the name of the graph queried.

    Supports GRAPH only, not GRAPH_TABLE. The query must start (after
    optional whitespace) with ``GRAPH <dataset>.<graph>``. The name may be
    backtick-quoted, either as a whole (```ds.g```) or per component
    (```ds`.`g```); the backticks are not part of the returned ids.

    Args:
        query_text: The SQL query text.

    Returns:
        A (dataset_id, graph_id) tuple, or None if the graph name cannot be
        determined. None is returned when the query does not start with a
        GRAPH clause, when a backtick quote is left unclosed, and when the
        name does not consist of exactly two non-empty dot-separated parts
        (for example a project-qualified ``project.dataset.graph``).
    """
    # The name token is a run of backtick-quoted segments (which may contain
    # whitespace) and unquoted non-space characters. Capturing the whole
    # token, instead of greedily splitting on a dot inside the regex, keeps
    # quote characters and extra path components out of the returned ids.
    match = re.match(
        r"\s*GRAPH\s+((?:`[^`]*`|[^\s`])+)", query_text, re.IGNORECASE
    )
    if match is None:
        return None

    parts = match.group(1).replace("`", "").split(".")
    if len(parts) != 2 or not all(parts):
        return None
    return (parts[0], parts[1])
795+
796+
def _get_graph_schema(
    bq_client: bigquery.client.Client,
    query_text: str,
    query_job: bigquery.job.QueryJob,
):
    """Looks up the schema of the property graph referenced by a GRAPH query.

    The graph name is extracted from ``query_text`` and its metadata is read
    from the ``INFORMATION_SCHEMA.PROPERTY_GRAPHS`` view of the graph's
    dataset. The dataset is looked up in the project of ``query_job``'s
    destination table.

    Args:
        bq_client: Client used to run the metadata query.
        query_text: The SQL query text whose graph should be described.
        query_job: The finished query job; only the project of its
            destination table is read.

    Returns:
        The schema as a JSON string in the visualization framework format,
        or None when the graph name cannot be determined or the metadata
        query does not yield exactly one row with one column.
    """
    graph_name_result = _get_graph_name(query_text)
    if graph_name_result is None:
        return None
    dataset_id, graph_id = graph_name_result

    # The dataset id is an identifier and cannot be bound as a query
    # parameter, so it has to be spliced into a backtick-quoted path. Refuse
    # ids that would terminate that quoting instead of sending malformed SQL.
    if "`" in dataset_id:
        return None

    project_id = query_job.configuration.destination.project
    info_schema_query = f"""
        SELECT PROPERTY_GRAPH_METADATA_JSON
        FROM `{project_id}.{dataset_id}`.INFORMATION_SCHEMA.PROPERTY_GRAPHS
        WHERE PROPERTY_GRAPH_NAME = @graph_id
    """
    # The graph name originates from user-written query text; pass it as a
    # bound parameter rather than interpolating it into a string literal.
    job_config = bigquery.QueryJobConfig(
        query_parameters=[
            bigquery.ScalarQueryParameter("graph_id", "STRING", graph_id)
        ]
    )
    info_schema_results = bq_client.query(
        info_schema_query, job_config=job_config
    ).to_dataframe()

    # Exactly one matching graph is expected; anything else is "unknown".
    if info_schema_results.shape == (1, 1):
        return _convert_schema(info_schema_results.iloc[0, 0])
    return None
813+
814+
815+ def _add_graph_widget (bq_client : Any , query_result : pandas .DataFrame , query_text : str , query_job : Any , args : Any ):
675816 try :
676817 from spanner_graphs .graph_visualization import generate_visualization_html
677818 except ImportError as err :
@@ -723,6 +864,8 @@ def _add_graph_widget(query_result: pandas.DataFrame, query_job: Any, args: Any)
723864 )
724865 return
725866
867+ schema = _get_graph_schema (bq_client , query_text , query_job )
868+
726869 table_dict = {
727870 "projectId" : query_job .configuration .destination .project ,
728871 "datasetId" : query_job .configuration .destination .dataset_id ,
@@ -733,6 +876,9 @@ def _add_graph_widget(query_result: pandas.DataFrame, query_job: Any, args: Any)
733876 if estimated_size < MAX_GRAPH_VISUALIZATION_QUERY_RESULT_SIZE :
734877 params_dict ["query_result" ] = json .loads (query_result .to_json ())
735878
879+ if schema is not None :
880+ params_dict ["schema" ] = schema
881+
736882 params_str = json .dumps (params_dict )
737883 html_content = generate_visualization_html (
738884 query = "placeholder query" ,
@@ -746,13 +892,6 @@ def _add_graph_widget(query_result: pandas.DataFrame, query_job: Any, args: Any)
746892 '"graph_visualization.NodeExpansion"' ,
747893 '"bigquery.graph_visualization.NodeExpansion"' ,
748894 )
749- html_content = html_content .replace (
750- '"graph_visualization.Query"' , '"bigquery.graph_visualization.Query"'
751- )
752- html_content = html_content .replace (
753- '"graph_visualization.NodeExpansion"' ,
754- '"bigquery.graph_visualization.NodeExpansion"' ,
755- )
756895 IPython .display .display (IPython .core .display .HTML (html_content ))
757896
758897
@@ -860,7 +999,7 @@ def _make_bq_query(
860999 result = result .to_dataframe (** dataframe_kwargs )
8611000
8621001 if args .graph and _supports_graph_widget (result ):
863- _add_graph_widget (result , query_job , args )
1002+ _add_graph_widget (bq_client , result , query , query_job , args )
8641003 return _handle_result (result , args )
8651004
8661005
0 commit comments