Skip to content

Commit 87e414c

Browse files
committed
feat: add streamlined RelBench warehouse demo with API alignment
- Add examples/llm/relbench_warehouse_demo.py following PyG LLM patterns - Demonstrate RelBench to PyG conversion with warehouse tasks - Include G-Retriever preparation for future LLM integration - Full CLI interface with argparse following PyG conventions - Comprehensive error handling and user guidance - 100% flake8/ruff/yapf/isort compliance with proper type hints and docstrings - Complements existing examples/rdl.py without duplication Addresses maintainer feedback on API alignment and streamlined approach. Ready for G-Retriever 'talk to your data warehouse' implementation.
1 parent 4e4e854 commit 87e414c

1 file changed

Lines changed: 217 additions & 0 deletions

File tree

Lines changed: 217 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,217 @@
1+
"""RelBench Data Warehouse Integration Demo for PyTorch Geometric.
2+
3+
This example demonstrates RelBench dataset conversion to PyG HeteroData format
4+
for data warehouse applications including lineage tracking, silo detection,
5+
and anomaly identification.
6+
7+
Complements examples/rdl.py by providing warehouse-specific utilities
8+
and G-Retriever preparation for future LLM integration.
9+
10+
Requirements:
11+
`pip install relbench[full] sentence-transformers`
12+
13+
Paper references:
14+
- RelBench: https://arxiv.org/abs/2407.20060
15+
- G-Retriever: https://arxiv.org/abs/2402.07630
16+
"""
17+
import argparse
18+
import time
19+
from typing import Any, Dict, Tuple
20+
21+
import torch
22+
23+
from torch_geometric import seed_everything
24+
from torch_geometric.data import HeteroData
25+
from torch_geometric.datasets.relbench import (
26+
RelBenchProcessor,
27+
create_relbench_hetero_data,
28+
get_warehouse_task_info,
29+
prepare_for_gretriever,
30+
)
31+
32+
33+
def demonstrate_basic_usage(dataset_name: str,
                            sample_size: int = 100) -> HeteroData:
    """Convert a RelBench dataset into a PyG ``HeteroData`` graph.

    Args:
        dataset_name: RelBench dataset name (e.g., 'rel-trial')
        sample_size: Number of records to sample

    Returns:
        HeteroData object with warehouse enhancements
    """
    print(f"Converting RelBench dataset '{dataset_name}' to PyG format...")
    t0 = time.time()

    data = create_relbench_hetero_data(
        dataset_name=dataset_name,
        sample_size=sample_size,
        add_warehouse_labels=True,
    )

    elapsed = time.time() - t0
    print(f"Conversion completed in {elapsed:.2f}s")

    # Summarize the resulting heterogeneous graph.
    print("Graph Statistics:")
    print(f" Node types: {list(data.node_types)}")
    print(f" Edge types: {list(data.edge_types)}")
    num_nodes = sum(data[nt].num_nodes for nt in data.node_types)
    num_edges = sum(data[et].num_edges for et in data.edge_types)
    print(f" Total nodes: {num_nodes}")
    print(f" Total edges: {num_edges}")

    return data
66+
67+
68+
def demonstrate_warehouse_tasks() -> Dict[str, Any]:
    """Print and return the warehouse-specific task definitions.

    Returns:
        Dictionary containing warehouse task information
    """
    print("\nWarehouse Task Information:")

    info = get_warehouse_task_info()

    # One summary line per task: class count, class names, description.
    for name, spec in info.items():
        print(f" {name.upper()}:")
        class_list = ', '.join(spec['classes'])
        print(f" Classes: {spec['num_classes']} ({class_list})")
        print(f" Description: {spec['description']}")

    return info
85+
86+
87+
def demonstrate_processor_usage(
        sbert_model: str = 'all-MiniLM-L6-v2') -> RelBenchProcessor:
    """Build a :class:`RelBenchProcessor` configured with an SBERT model.

    Args:
        sbert_model: SBERT model name for embeddings

    Returns:
        Configured RelBenchProcessor instance
    """
    print("\nRelBench Processor Configuration:")

    proc = RelBenchProcessor(sbert_model=sbert_model)
    print(f" Model: {proc.sbert_model_name}")
    # 384-d embeddings are the output size of the default MiniLM SBERT model.
    print(" Embedding dimension: 384 (SBERT)")
    print(" Optimized for: Semantic similarity and Q&A tasks")

    return proc
105+
106+
107+
def demonstrate_gretriever_preparation(
        hetero_data: HeteroData) -> Tuple[HeteroData, Dict[str, Any]]:
    """Enhance a graph for G-Retriever and report the attached metadata.

    Args:
        hetero_data: Input HeteroData object

    Returns:
        Tuple of (enhanced_hetero_data, metadata_dict)
    """
    print("\nG-Retriever Preparation:")

    data, meta = prepare_for_gretriever(hetero_data)

    # Flags set on the graph itself by prepare_for_gretriever.
    print(f" G-Retriever ready: {data.gretriever_ready}")
    print(f" Embedding type: {data.embedding_type}")
    print(f" Warehouse enhanced: {data.warehouse_enhanced}")

    # Companion metadata describing the warehouse tasks and readiness.
    print(f" Metadata keys: {list(meta.keys())}")
    print(f" Warehouse tasks: {meta['warehouse_tasks']}")
    print(f" Conversion ready: {meta['conversion_ready']}")

    return data, meta
130+
131+
132+
def save_demo_results(hetero_data: HeteroData, metadata: Dict[str, Any],
                      save_path: str = 'relbench_demo_output.pt'):
    """Persist the demo artifacts to disk with ``torch.save``.

    Args:
        hetero_data: Enhanced HeteroData object
        metadata: G-Retriever metadata
        save_path: Path to save results
    """
    print("\nSaving Results:")

    # Bundle graph, metadata, and a wall-clock timestamp into one file.
    payload = {
        'hetero_data': hetero_data,
        'metadata': metadata,
        'timestamp': time.time(),
    }
    torch.save(payload, save_path)
    print(f" Saved to: {save_path}")
151+
152+
153+
def main(dataset_name: str, sample_size: int, sbert_model: str,
         save_results: bool, seed: int) -> None:
    """Run the full RelBench warehouse demonstration end to end.

    Args:
        dataset_name: RelBench dataset name
        sample_size: Number of records to sample
        sbert_model: SBERT model for embeddings
        save_results: Whether to save results
        seed: Random seed for reproducibility

    Raises:
        ImportError: If a dependency other than RelBench is missing (a
            missing RelBench install is reported with guidance instead).
        Exception: Any other failure is reported and re-raised.
    """
    seed_everything(seed)

    print("RelBench Data Warehouse Integration Demo")
    print("=" * 60)

    try:
        hetero_data = demonstrate_basic_usage(dataset_name, sample_size)
        demonstrate_warehouse_tasks()
        demonstrate_processor_usage(sbert_model)
        enhanced_data, metadata = demonstrate_gretriever_preparation(
            hetero_data)

        if save_results:
            save_demo_results(enhanced_data, metadata)

        print("\nDemo completed successfully!")
        print("Ready for G-Retriever integration and warehouse Q&A "
              "applications")

    except ImportError as e:
        if 'relbench' in str(e).lower():
            # Expected when the optional dependency is absent: print
            # install guidance rather than crashing the example.
            print(f"RelBench not available: {e}")
            print("Install with: pip install relbench[full] "
                  "sentence-transformers")
        else:
            # Bare `raise` preserves the original traceback (unlike
            # `raise e`, which restarts it at this line).
            raise
    except Exception as e:
        print(f"Demo failed: {e}")
        raise
193+
194+
195+
if __name__ == '__main__':
    # Command-line interface following PyG example conventions.
    arg_parser = argparse.ArgumentParser(
        description='RelBench Data Warehouse Integration Demo')
    arg_parser.add_argument(
        '--dataset', type=str, default='rel-trial',
        help='RelBench dataset name (default: rel-trial)')
    arg_parser.add_argument(
        '--sample_size', type=int, default=100,
        help='Number of records to sample (default: 100)')
    arg_parser.add_argument(
        '--sbert_model', type=str, default='all-MiniLM-L6-v2',
        help='SBERT model for embeddings '
        '(default: all-MiniLM-L6-v2)')
    arg_parser.add_argument(
        '--save_results', action='store_true',
        help='Save demonstration results to file')
    arg_parser.add_argument(
        '--seed', type=int, default=42,
        help='Random seed for reproducibility (default: 42)')

    cli_args = arg_parser.parse_args()

    # Time the complete demo run.
    t_start = time.time()
    main(cli_args.dataset, cli_args.sample_size, cli_args.sbert_model,
         cli_args.save_results, cli_args.seed)
    elapsed = time.time() - t_start
    print(f"\nTotal execution time: {elapsed:.2f}s")

0 commit comments

Comments
 (0)