Skip to content

Commit b2d30dc

Browse files
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent 590f092 commit b2d30dc

1 file changed

Lines changed: 12 additions & 12 deletions

File tree

examples/llm/relbench_warehouse_demo.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -16,16 +16,17 @@
1616
"""
1717
import argparse
1818
import time
19-
from typing import Dict, Any, Tuple
19+
from typing import Any, Dict, Tuple
2020

2121
import torch
22+
2223
from torch_geometric import seed_everything
2324
from torch_geometric.data import HeteroData
2425
from torch_geometric.datasets.relbench import (
2526
RelBenchProcessor,
2627
create_relbench_hetero_data,
2728
get_warehouse_task_info,
28-
prepare_for_gretriever
29+
prepare_for_gretriever,
2930
)
3031

3132

@@ -44,11 +45,9 @@ def demonstrate_basic_usage(dataset_name: str,
4445
start_time = time.time()
4546

4647
# Create HeteroData with warehouse labels
47-
hetero_data = create_relbench_hetero_data(
48-
dataset_name=dataset_name,
49-
sample_size=sample_size,
50-
add_warehouse_labels=True
51-
)
48+
hetero_data = create_relbench_hetero_data(dataset_name=dataset_name,
49+
sample_size=sample_size,
50+
add_warehouse_labels=True)
5251

5352
conversion_time = time.time() - start_time
5453
print(f"✅ Conversion completed in {conversion_time:.2f}s")
@@ -210,9 +209,10 @@ def main(dataset_name: str, sample_size: int, sbert_model: str,
210209
help='RelBench dataset name (default: rel-trial)')
211210
parser.add_argument('--sample_size', type=int, default=100,
212211
help='Number of records to sample (default: 100)')
213-
parser.add_argument('--sbert_model', type=str, default='all-MiniLM-L6-v2',
214-
help='SBERT model for embeddings '
215-
'(default: all-MiniLM-L6-v2)')
212+
parser.add_argument(
213+
'--sbert_model', type=str, default='all-MiniLM-L6-v2',
214+
help='SBERT model for embeddings '
215+
'(default: all-MiniLM-L6-v2)')
216216
parser.add_argument('--save_results', action='store_true',
217217
help='Save demonstration results to file')
218218
parser.add_argument('--seed', type=int, default=42,
@@ -221,7 +221,7 @@ def main(dataset_name: str, sample_size: int, sbert_model: str,
221221
args = parser.parse_args()
222222

223223
start_time = time.time()
224-
main(args.dataset, args.sample_size, args.sbert_model,
225-
args.save_results, args.seed)
224+
main(args.dataset, args.sample_size, args.sbert_model, args.save_results,
225+
args.seed)
226226
total_time = time.time() - start_time
227227
print(f"\nTotal execution time: {total_time:.2f}s")

0 commit comments

Comments
 (0)