1616"""
1717import argparse
1818import time
19- from typing import Dict , Any , Tuple
19+ from typing import Any , Dict , Tuple
2020
2121import torch
22+
2223from torch_geometric import seed_everything
2324from torch_geometric .data import HeteroData
2425from torch_geometric .datasets .relbench import (
2526 RelBenchProcessor ,
2627 create_relbench_hetero_data ,
2728 get_warehouse_task_info ,
28- prepare_for_gretriever
29+ prepare_for_gretriever ,
2930)
3031
3132
@@ -44,11 +45,9 @@ def demonstrate_basic_usage(dataset_name: str,
4445 start_time = time .time ()
4546
4647 # Create HeteroData with warehouse labels
47- hetero_data = create_relbench_hetero_data (
48- dataset_name = dataset_name ,
49- sample_size = sample_size ,
50- add_warehouse_labels = True
51- )
48+ hetero_data = create_relbench_hetero_data (dataset_name = dataset_name ,
49+ sample_size = sample_size ,
50+ add_warehouse_labels = True )
5251
5352 conversion_time = time .time () - start_time
5453 print (f"✅ Conversion completed in { conversion_time :.2f} s" )
@@ -210,9 +209,10 @@ def main(dataset_name: str, sample_size: int, sbert_model: str,
210209 help = 'RelBench dataset name (default: rel-trial)' )
211210 parser .add_argument ('--sample_size' , type = int , default = 100 ,
212211 help = 'Number of records to sample (default: 100)' )
213- parser .add_argument ('--sbert_model' , type = str , default = 'all-MiniLM-L6-v2' ,
214- help = 'SBERT model for embeddings '
215- '(default: all-MiniLM-L6-v2)' )
212+ parser .add_argument (
213+ '--sbert_model' , type = str , default = 'all-MiniLM-L6-v2' ,
214+ help = 'SBERT model for embeddings '
215+ '(default: all-MiniLM-L6-v2)' )
216216 parser .add_argument ('--save_results' , action = 'store_true' ,
217217 help = 'Save demonstration results to file' )
218218 parser .add_argument ('--seed' , type = int , default = 42 ,
@@ -221,7 +221,7 @@ def main(dataset_name: str, sample_size: int, sbert_model: str,
221221 args = parser .parse_args ()
222222
223223 start_time = time .time ()
224- main (args .dataset , args .sample_size , args .sbert_model ,
225- args .save_results , args . seed )
224+ main (args .dataset , args .sample_size , args .sbert_model , args . save_results ,
225+ args .seed )
226226 total_time = time .time () - start_time
227227 print (f"\n Total execution time: { total_time :.2f} s" )
0 commit comments