11"""RelBench integration utilities for PyTorch Geometric.
22
33Provides utilities for converting RelBench datasets to PyG HeteroData objects
4- with semantic embeddings and graph structure for reverse engineering tasks .
4+ with semantic embeddings and graph structure for warehouse applications .
55
6- TODO: Add subgraph sampling utilities for few-shot inference
7- TODO: Implement more sophisticated edge weighting schemes
8- TODO: Add support for temporal lineage tracking
6+ TODO: Add subgraph sampling utilities for inference
7+ TODO: Implement configurable edge weighting schemes
8+ TODO: Add support for lineage tracking
99"""
1010
1111import warnings
6969
7070
7171class RelBenchProcessor :
72- """Converts RelBench datasets to PyG HeteroData with unified records ."""
72+ """Utility for converting RelBench datasets to PyG HeteroData format ."""
7373 def __init__ (self , sbert_model : str = 'all-MiniLM-L6-v2' ) -> None :
7474 """Initialize processor with SBERT model."""
7575 if not RELBENCH_AVAILABLE :
@@ -500,7 +500,7 @@ def _can_infer_anomalies(self, table_name: Optional[str], db: Any) -> bool:
500500 # Structural inference methods
501501 def _infer_lineage_from_structure (self , table_name : str ,
502502 db : Any ) -> torch .Tensor :
503- """Infer ETL lineage stage from table structure ."""
503+ """Generate lineage labels using table metadata ."""
504504 table_df = db .table_dict [table_name ].df
505505
506506 # Count foreign key columns
@@ -535,7 +535,7 @@ def _infer_lineage_from_structure(self, table_name: str,
535535
536536 def _infer_silo_from_connectivity (self , table_name : str , db : Any ,
537537 num_nodes : int ) -> torch .Tensor :
538- """Infer silo detection labels from table connectivity."""
538+ """Generate silo labels using connectivity information ."""
539539 # Count connections to other tables
540540 connections = 0
541541
@@ -559,7 +559,7 @@ def _infer_silo_from_connectivity(self, table_name: str, db: Any,
559559
560560 def _infer_anomalies_from_statistics (self , table_name : str ,
561561 db : Any ) -> torch .Tensor :
562- """Infer anomaly detection labels from statistical analysis ."""
562+ """Generate anomaly labels using statistical methods ."""
563563 table_df = db .table_dict [table_name ].df
564564 num_nodes = len (table_df )
565565
@@ -590,7 +590,7 @@ def _infer_anomalies_from_statistics(self, table_name: str,
590590
591591 def _infer_record_lineage (self , table_name : str , db : Any ,
592592 num_records : int ) -> torch .Tensor :
593- """Infer lineage labels: 0=source, 1=intermediate, 2=target ."""
593+ """Generate lineage labels for individual records ."""
594594 table_df = db .table_dict [table_name ].df
595595
596596 # Count foreign keys in this table
@@ -616,7 +616,7 @@ def _infer_record_lineage(self, table_name: str, db: Any,
616616
617617 def _infer_record_silo (self , table_name : str , db : Any ,
618618 num_records : int ) -> torch .Tensor :
619- """Infer silo labels for individual records."""
619+ """Generate silo labels for individual records."""
620620 # Check table connectivity
621621 has_connections = False
622622 if hasattr (db , 'fkey_dict' ):
@@ -631,12 +631,12 @@ def _infer_record_silo(self, table_name: str, db: Any,
631631
632632 def _infer_record_anomaly (self , table_name : str , db : Any ,
633633 num_records : int ) -> torch .Tensor :
634- """Infer anomaly labels for individual records."""
634+ """Generate anomaly labels for individual records."""
635635 # Use existing statistical inference but return per-record labels
636636 return self ._infer_anomalies_from_statistics (table_name , db )
637637
638638 def _create_edges (self , hetero_data : HeteroData , db : Any ) -> None :
639- """Create edges for unified record space ."""
639+ """Create graph edges between records ."""
640640 if 'record' not in hetero_data .node_types :
641641 warnings .warn ('No record nodes found for edge creation' ,
642642 stacklevel = 2 )
@@ -649,7 +649,7 @@ def _create_edges(self, hetero_data: HeteroData, db: Any) -> None:
649649 self ._add_value_similarity_edges (hetero_data )
650650
651651 def _add_fk_edges_unified (self , hetero_data : HeteroData , db : Any ) -> None :
652- """Add FK -based edges between records in unified space ."""
652+ """Add relationship -based edges between records."""
653653 if not hasattr (db , 'fkey_dict' ) or not db .fkey_dict :
654654 return
655655
@@ -697,7 +697,7 @@ def _add_fk_edges_unified(self, hetero_data: HeteroData, db: Any) -> None:
697697 'record' ].edge_index = fk_edge_index
698698
699699 def _add_value_similarity_edges (self , hetero_data : HeteroData ) -> None :
700- """Add value similarity edges using existing embeddings ."""
700+ """Add similarity-based edges between nodes ."""
701701 if 'record' not in hetero_data .node_types :
702702 return
703703
@@ -772,7 +772,7 @@ def _add_sample_edges(self, hetero_data: HeteroData, src: str, rel: str,
772772 def _discover_real_relationships (
773773 self , db : Any ,
774774 node_types : List [str ]) -> List [Tuple [str , str , str ]]:
775- """Discover real relationships from RelBench metadata."""
775+ """Extract relationships from RelBench metadata."""
776776 relationships = []
777777
778778 # Method 1: Try to use edge_df if available
@@ -810,7 +810,7 @@ def _discover_real_relationships(
810810 def _infer_relationships_from_columns (
811811 self , db : Any ,
812812 node_types : List [str ]) -> List [Tuple [str , str , str ]]:
813- """Infer relationships from foreign key column patterns."""
813+ """Extract relationships from column patterns."""
814814 relationships = []
815815
816816 for table_name , table_obj in db .table_dict .items ():
@@ -847,7 +847,7 @@ def _add_real_edges_from_fk(
847847 dst_table : str ,
848848 db : Any ,
849849 ) -> None :
850- """Create real edges based on actual foreign key values ."""
850+ """Create edges using available relationship data ."""
851851 try :
852852 src_df = db .table_dict [src_table ].df
853853 dst_df = db .table_dict [dst_table ].df
@@ -919,10 +919,10 @@ def create_relbench_hetero_data(
919919 use_dummy_fallback : bool = False ,
920920 batch_size : int = 64 ,
921921) -> HeteroData :
922- """Create HeteroData from RelBench dataset with unified record nodes .
922+ """Create HeteroData from RelBench dataset.
923923
924924 TODO: Add support for custom edge types and weights
925- TODO: Implement temporal lineage tracking
925+ TODO: Implement lineage tracking
926926 """
927927 processor = RelBenchProcessor (sbert_model )
928928
@@ -968,21 +968,17 @@ def get_warehouse_task_info() -> Dict[str, Dict[str, Any]]:
968968 'data_availability' :
969969 'real_data' ,
970970 'notes' :
971- 'Based on actual foreign key connectivity analysis '
971+ 'Based on connectivity information '
972972 'from RelBench metadata.' ,
973973 },
974974 'anomaly' : {
975- 'num_classes' :
976- 2 ,
975+ 'num_classes' : 2 ,
977976 'classes' : ['normal' , 'anomaly' ],
978- 'description' :
979- 'Anomaly detection - identify unusual patterns '
977+ 'description' : 'Anomaly detection - identify unusual patterns '
980978 'in data warehouse' ,
981- 'data_availability' :
982- 'statistical_inference' ,
983- 'notes' :
984- 'Based on statistical outlier detection using IQR '
985- 'method on numeric columns.' ,
979+ 'data_availability' : 'statistical_inference' ,
980+ 'notes' : 'Based on statistical methods '
981+ 'applied to numeric columns.' ,
986982 },
987983 }
988984
0 commit comments