|
1 | 1 | """RelBench integration utilities for PyTorch Geometric. |
2 | 2 |
|
3 | 3 | Provides utilities for converting RelBench datasets to PyG HeteroData objects |
4 | | -with semantic embeddings and graph structure for warehouse applications. |
| 4 | +with semantic embeddings and warehouse-specific enhancements. |
5 | 5 |
|
6 | | -TODO: Add subgraph sampling utilities for inference |
7 | | -TODO: Implement configurable edge weighting schemes |
8 | | -TODO: Add support for lineage tracking |
| 6 | +Complements examples/rdl.py with G-Retriever preparation and warehouse tasks. |
9 | 7 | """ |
10 | 8 |
|
11 | 9 | import warnings |
@@ -1044,6 +1042,35 @@ def process(self) -> None: |
1044 | 1042 | torch.save((collated_data, slices), self.processed_paths[0]) |
1045 | 1043 |
|
1046 | 1044 |
|
| 1045 | +def prepare_for_gretriever( |
| 1046 | + hetero_data: HeteroData) -> Tuple[HeteroData, Dict[str, Any]]: |
| 1047 | + """Prepare RelBench HeteroData for G-Retriever training. |
| 1048 | +
|
| 1049 | + Enhances HeteroData with G-Retriever-specific attributes and metadata. |
| 1050 | +
|
| 1051 | + Args: |
| 1052 | + hetero_data: HeteroData object from RelBench integration |
| 1053 | +
|
| 1054 | + Returns: |
| 1055 | + Tuple of (enhanced_hetero_data, metadata_dict) |
| 1056 | + """ |
| 1057 | + metadata = { |
| 1058 | + 'embedding_dim': getattr(hetero_data, 'embedding_dim', 384), |
| 1059 | + 'node_types': list(hetero_data.node_types), |
| 1060 | + 'edge_types': list(hetero_data.edge_types), |
| 1061 | + 'warehouse_tasks': ['lineage', 'silo', 'anomaly'], |
| 1062 | + 'recommended_qa_pairs': get_warehouse_task_info(), |
| 1063 | + 'conversion_ready': True, |
| 1064 | + } |
| 1065 | + |
| 1066 | + # Add G-Retriever specific attributes |
| 1067 | + hetero_data.gretriever_ready = True |
| 1068 | + hetero_data.embedding_type = 'sbert' # Indicates SBERT embeddings |
| 1069 | + hetero_data.warehouse_enhanced = True |
| 1070 | + |
| 1071 | + return hetero_data, metadata |
| 1072 | + |
| 1073 | + |
1047 | 1074 | # Backward compatibility aliases |
1048 | 1075 | RelBenchHeteroDataProcessor = RelBenchProcessor |
1049 | 1076 | create_hetero_data_from_relbench = create_relbench_hetero_data |
0 commit comments