Skip to content

Commit 5bedfca

Browse files
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent 977147a commit 5bedfca

5 files changed

Lines changed: 172 additions & 204 deletions

File tree

examples/llm/relbench_example.py

Lines changed: 3 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,7 @@
11
"""RelBench integration example for PyTorch Geometric."""
22

33
import argparse
4+
45
from torch_geometric.utils.relbench import create_relbench_hetero_data
56

67

@@ -18,8 +19,8 @@ def main():
1819

1920
try:
2021
print(f"Loading RelBench dataset: {args.dataset}")
21-
hetero_data = create_relbench_hetero_data(
22-
args.dataset, sample_size=args.sample_size)
22+
hetero_data = create_relbench_hetero_data(args.dataset,
23+
sample_size=args.sample_size)
2324
print("Dataset loaded successfully")
2425

2526
print("HeteroData Summary:")

examples/relbench/02_train_rgcn.py

Lines changed: 1 addition & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -1,5 +1,4 @@
1-
"""
2-
RelBench R-GCN Training Example.
1+
"""RelBench R-GCN Training Example.
32
43
This example demonstrates how to train an R-GCN model on RelBench data
54
for data warehouse lineage prediction tasks.

torch_geometric/datasets/relbench.py

Lines changed: 12 additions & 19 deletions
Original file line number | Diff line number | Diff line change
@@ -258,8 +258,7 @@ def _add_warehouse_labels(self, node_store: Any, num_nodes: int,
258258
create_silo_labels: bool = True,
259259
create_anomaly_labels: bool = True,
260260
use_dummy_fallback: bool = False) -> None:
261-
"""
262-
Add warehouse task labels with 'Ready-for-Real-Data' pattern.
261+
"""Add warehouse task labels with 'Ready-for-Real-Data' pattern.
263262
264263
Precedence order: Real Data > Structural Inference > Dummy
265264
Fallback > None
@@ -274,7 +273,6 @@ def _add_warehouse_labels(self, node_store: Any, num_nodes: int,
274273
create_anomaly_labels: Whether to create anomaly detection labels
275274
use_dummy_fallback: Whether to use dummy data as last resort
276275
"""
277-
278276
# ETL Lineage Labels
279277
if create_lineage_labels:
280278
lineage_labels = self._get_lineage_labels(table_name, db,
@@ -299,8 +297,8 @@ def _get_lineage_labels(
299297
self, table_name: Optional[str], db: Any, num_nodes: int,
300298
use_dummy_fallback: bool) -> Optional[torch.Tensor]:
301299
"""Get ETL lineage labels with precedence: real > inferred >
302-
dummy > None."""
303-
300+
dummy > None.
301+
"""
304302
# Method 1: Check for real lineage data
305303
if self._has_real_lineage(db, table_name):
306304
return self._load_real_lineage(db, table_name)
@@ -326,8 +324,8 @@ def _get_silo_labels(self, table_name: Optional[str], db: Any,
326324
num_nodes: int,
327325
use_dummy_fallback: bool) -> Optional[torch.Tensor]:
328326
"""Get silo detection labels with precedence: real > inferred >
329-
dummy > None."""
330-
327+
dummy > None.
328+
"""
331329
# Method 1: Check for real silo data
332330
if self._has_real_silo_data(db, table_name):
333331
return self._load_real_silo_labels(db, table_name)
@@ -341,8 +339,8 @@ def _get_anomaly_labels(
341339
self, table_name: Optional[str], db: Any, num_nodes: int,
342340
use_dummy_fallback: bool) -> Optional[torch.Tensor]:
343341
"""Get anomaly detection labels with precedence: real > inferred >
344-
dummy > None."""
345-
342+
dummy > None.
343+
"""
346344
# Method 1: Check for real anomaly data
347345
if self._has_real_anomaly_data(db, table_name):
348346
return self._load_real_anomaly_labels(db, table_name)
@@ -366,8 +364,7 @@ def _get_anomaly_labels(
366364

367365
# Real data checking methods
368366
def _has_real_lineage(self, db: Any, table_name: Optional[str]) -> bool:
369-
"""
370-
Check if real ETL lineage data is available.
367+
"""Check if real ETL lineage data is available.
371368
372369
Args:
373370
db: RelBench database object
@@ -381,8 +378,7 @@ def _has_real_lineage(self, db: Any, table_name: Optional[str]) -> bool:
381378
and 'etl_stages' in db.lineage_metadata[table_name])
382379

383380
def _has_real_silo_data(self, db: Any, table_name: Optional[str]) -> bool:
384-
"""
385-
Check if real silo detection data is available.
381+
"""Check if real silo detection data is available.
386382
387383
Args:
388384
db: RelBench database object
@@ -396,8 +392,7 @@ def _has_real_silo_data(self, db: Any, table_name: Optional[str]) -> bool:
396392

397393
def _has_real_anomaly_data(self, db: Any,
398394
table_name: Optional[str]) -> bool:
399-
"""
400-
Check if real anomaly detection data is available.
395+
"""Check if real anomaly detection data is available.
401396
402397
Args:
403398
db: RelBench database object
@@ -481,7 +476,6 @@ def _infer_lineage_from_structure(self, table_name: str,
481476
def _infer_silo_from_connectivity(self, table_name: str, db: Any,
482477
num_nodes: int) -> torch.Tensor:
483478
"""Infer silo detection labels from table connectivity."""
484-
485479
# Count connections to other tables
486480
connections = 0
487481

@@ -524,9 +518,8 @@ def _infer_anomalies_from_statistics(self, table_name: str,
524518

525519
if IQR > 0: # Avoid division by zero
526520
# Mark outliers as anomalies
527-
outlier_mask = ((values <
528-
(Q1 - 1.5 * IQR)) | (values >
529-
(Q3 + 1.5 * IQR)))
521+
outlier_mask = ((values < (Q1 - 1.5 * IQR)) |
522+
(values > (Q3 + 1.5 * IQR)))
530523

531524
# Update anomaly labels for outlier rows
532525
outlier_indices = table_df[col].index[table_df[col].isin(

torch_geometric/utils/__init__.py

Lines changed: 0 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -59,7 +59,6 @@
5959
from ._train_test_split_edges import train_test_split_edges
6060
from .influence import total_influence
6161

62-
6362
__all__ = [
6463
'scatter',
6564
'group_argsort',
@@ -154,7 +153,6 @@
154153
'total_influence',
155154
]
156155

157-
158156
# `structured_negative_sampling_feasible` is a long name and thus destroys the
159157
# documentation rendering. We remove it for now from the documentation:
160158
classes = copy.copy(__all__)

0 commit comments

Comments
 (0)