|
"""RelBench Data Warehouse Integration Demo for PyTorch Geometric.

This example demonstrates RelBench dataset conversion to PyG HeteroData format
for data warehouse applications, including lineage tracking, silo detection,
and anomaly identification.

Complements examples/rdl.py by providing warehouse-specific utilities
and G-Retriever preparation for future LLM integration.

Requirements:
`pip install relbench[full] sentence-transformers`

Paper references:
- RelBench: https://arxiv.org/abs/2407.20060
- G-Retriever: https://arxiv.org/abs/2402.07630
"""
| 17 | +import argparse |
| 18 | +import time |
| 19 | +from typing import Any, Dict, Tuple |
| 20 | + |
| 21 | +import torch |
| 22 | + |
| 23 | +from torch_geometric import seed_everything |
| 24 | +from torch_geometric.data import HeteroData |
| 25 | +from torch_geometric.datasets.relbench import ( |
| 26 | + RelBenchProcessor, |
| 27 | + create_relbench_hetero_data, |
| 28 | + get_warehouse_task_info, |
| 29 | + prepare_for_gretriever, |
| 30 | +) |
| 31 | + |
| 32 | + |
def demonstrate_basic_usage(dataset_name: str,
                            sample_size: int = 100) -> HeteroData:
    """Convert a RelBench dataset to PyG format and report graph statistics.

    Args:
        dataset_name: RelBench dataset name (e.g., 'rel-trial')
        sample_size: Number of records to sample

    Returns:
        HeteroData object with warehouse enhancements
    """
    print(f"Converting RelBench dataset '{dataset_name}' to PyG format...")
    tic = time.time()

    hetero_data = create_relbench_hetero_data(
        dataset_name=dataset_name,
        sample_size=sample_size,
        add_warehouse_labels=True,
    )

    print(f"Conversion completed in {time.time() - tic:.2f}s")

    # Summarize the heterogeneous graph that was produced.
    print("Graph Statistics:")
    print(f" Node types: {list(hetero_data.node_types)}")
    print(f" Edge types: {list(hetero_data.edge_types)}")
    num_nodes = sum(hetero_data[t].num_nodes for t in hetero_data.node_types)
    num_edges = sum(hetero_data[t].num_edges for t in hetero_data.edge_types)
    print(f" Total nodes: {num_nodes}")
    print(f" Total edges: {num_edges}")

    return hetero_data
| 66 | + |
| 67 | + |
def demonstrate_warehouse_tasks() -> Dict[str, Any]:
    """Print and return the warehouse-specific task definitions.

    Returns:
        Dictionary containing warehouse task information
    """
    print("\nWarehouse Task Information:")

    task_info = get_warehouse_task_info()

    # One summary section per warehouse task (lineage, silo, anomaly, ...).
    for name, info in task_info.items():
        print(f" {name.upper()}:")
        joined = ', '.join(info['classes'])
        print(f" Classes: {info['num_classes']} ({joined})")
        print(f" Description: {info['description']}")

    return task_info
| 85 | + |
| 86 | + |
def demonstrate_processor_usage(
        sbert_model: str = 'all-MiniLM-L6-v2') -> RelBenchProcessor:
    """Build and describe a RelBenchProcessor with a custom SBERT model.

    Args:
        sbert_model: SBERT model name for embeddings

    Returns:
        Configured RelBenchProcessor instance
    """
    print("\nRelBench Processor Configuration:")

    proc = RelBenchProcessor(sbert_model=sbert_model)

    print(f" Model: {proc.sbert_model_name}")
    print(" Embedding dimension: 384 (SBERT)")
    print(" Optimized for: Semantic similarity and Q&A tasks")

    return proc
| 105 | + |
| 106 | + |
def demonstrate_gretriever_preparation(
        hetero_data: HeteroData) -> Tuple[HeteroData, Dict[str, Any]]:
    """Prepare a graph for G-Retriever and report the resulting metadata.

    Args:
        hetero_data: Input HeteroData object

    Returns:
        Tuple of (enhanced_hetero_data, metadata_dict)
    """
    print("\nG-Retriever Preparation:")

    enhanced, meta = prepare_for_gretriever(hetero_data)

    # Attributes set on the returned graph by prepare_for_gretriever.
    print(f" G-Retriever ready: {enhanced.gretriever_ready}")
    print(f" Embedding type: {enhanced.embedding_type}")
    print(f" Warehouse enhanced: {enhanced.warehouse_enhanced}")

    print(f" Metadata keys: {list(meta.keys())}")
    print(f" Warehouse tasks: {meta['warehouse_tasks']}")
    print(f" Conversion ready: {meta['conversion_ready']}")

    return enhanced, meta
| 130 | + |
| 131 | + |
def save_demo_results(hetero_data: HeteroData, metadata: Dict[str, Any],
                      save_path: str = 'relbench_demo_output.pt'):
    """Persist demonstration artifacts to disk via ``torch.save``.

    Args:
        hetero_data: Enhanced HeteroData object
        metadata: G-Retriever metadata
        save_path: Path to save results
    """
    print("\nSaving Results:")

    # Bundle the graph, its metadata, and a wall-clock timestamp together
    # so a later session can reload everything in one torch.load call.
    payload = {
        'hetero_data': hetero_data,
        'metadata': metadata,
        'timestamp': time.time(),
    }
    torch.save(payload, save_path)

    print(f" Saved to: {save_path}")
| 151 | + |
| 152 | + |
def main(dataset_name: str, sample_size: int, sbert_model: str,
         save_results: bool, seed: int) -> None:
    """Run the end-to-end RelBench warehouse demonstration.

    Args:
        dataset_name: RelBench dataset name
        sample_size: Number of records to sample
        sbert_model: SBERT model for embeddings
        save_results: Whether to save results
        seed: Random seed for reproducibility
    """
    seed_everything(seed)

    print("RelBench Data Warehouse Integration Demo")
    print("=" * 60)

    try:
        hetero_data = demonstrate_basic_usage(dataset_name, sample_size)
        demonstrate_warehouse_tasks()
        demonstrate_processor_usage(sbert_model)
        enhanced_data, metadata = demonstrate_gretriever_preparation(
            hetero_data)

        if save_results:
            save_demo_results(enhanced_data, metadata)

        print("\nDemo completed successfully!")
        print("Ready for G-Retriever integration and warehouse Q&A "
              "applications")

    except ImportError as e:
        # Missing optional dependency: print an actionable install hint
        # instead of a bare traceback.
        if 'relbench' in str(e).lower():
            print(f"RelBench not available: {e}")
            print("Install with: pip install relbench[full] "
                  "sentence-transformers")
        else:
            # Unrelated import failure: bare `raise` keeps the original
            # traceback intact (unlike `raise e`, which appends a frame).
            raise
    except Exception as e:
        print(f"Demo failed: {e}")
        raise
| 193 | + |
| 194 | + |
if __name__ == '__main__':
    # Command-line interface for the demo script.
    parser = argparse.ArgumentParser(
        description='RelBench Data Warehouse Integration Demo')
    parser.add_argument('--dataset', type=str, default='rel-trial',
                        help='RelBench dataset name (default: rel-trial)')
    parser.add_argument('--sample_size', type=int, default=100,
                        help='Number of records to sample (default: 100)')
    parser.add_argument('--sbert_model', type=str, default='all-MiniLM-L6-v2',
                        help='SBERT model for embeddings '
                        '(default: all-MiniLM-L6-v2)')
    parser.add_argument('--save_results', action='store_true',
                        help='Save demonstration results to file')
    parser.add_argument('--seed', type=int, default=42,
                        help='Random seed for reproducibility (default: 42)')
    args = parser.parse_args()

    # Time the full run, including dataset download/conversion.
    wall_start = time.time()
    main(args.dataset, args.sample_size, args.sbert_model,
         args.save_results, args.seed)
    print(f"\nTotal execution time: {time.time() - wall_start:.2f}s")