|
from types import SimpleNamespace
from typing import Optional

import pandas as pd
import torch

from torch_geometric.testing import withPackage
from torch_geometric.utils import from_relbench
| 8 | + |
| 9 | + |
| 10 | +def _mock_table( |
| 11 | + df: pd.DataFrame, |
| 12 | + fkey_col_to_pkey_table: dict, |
| 13 | + pkey_col: str = None, |
| 14 | + time_col: str = None, |
| 15 | +) -> SimpleNamespace: |
| 16 | + """Create a mock object that duck-types relbench.base.Table.""" |
| 17 | + return SimpleNamespace( |
| 18 | + df=df, |
| 19 | + fkey_col_to_pkey_table=fkey_col_to_pkey_table, |
| 20 | + pkey_col=pkey_col, |
| 21 | + time_col=time_col, |
| 22 | + ) |
| 23 | + |
| 24 | + |
| 25 | +def _mock_database(table_dict: dict) -> SimpleNamespace: |
| 26 | + """Create a mock object that duck-types relbench.base.Database.""" |
| 27 | + return SimpleNamespace(table_dict=table_dict) |
| 28 | + |
| 29 | + |
def test_from_relbench():
    """Convert a users/posts database and check nodes, features, edges."""
    user_table = _mock_table(
        df=pd.DataFrame({
            'id': [0, 1, 2],
            'age': [25, 30, 35],
            'score': [1.0, 2.0, 3.0],
        }),
        fkey_col_to_pkey_table={},
        pkey_col='id',
    )
    post_table = _mock_table(
        df=pd.DataFrame({
            'id': [0, 1, 2, 3],
            'user_id': [0, 1, 0, 2],
            'length': [100, 200, 150, 300],
        }),
        fkey_col_to_pkey_table={'user_id': 'users'},
        pkey_col='id',
    )
    tables = {'users': user_table, 'posts': post_table}
    data = from_relbench(_mock_database(table_dict=tables))

    # Each table becomes its own node type:
    node_types = data.node_types
    assert 'users' in node_types
    assert 'posts' in node_types

    # One node per table row:
    user_store, post_store = data['users'], data['posts']
    assert user_store.num_nodes == 3
    assert post_store.num_nodes == 4

    # Numeric non-key columns end up in `x`:
    assert user_store.x is not None
    assert user_store.x.size() == (3, 2)  # age, score
    assert post_store.x is not None
    assert post_store.x.size() == (4, 1)  # length

    # Feature values follow the row and column order of the table:
    expected_user_x = torch.tensor([[25, 1.0], [30, 2.0], [35, 3.0]])
    assert torch.allclose(user_store.x, expected_user_x)

    # Each foreign key yields a forward and a reverse edge type:
    fwd_type = ('posts', 'f2p_user_id', 'users')
    rev_type = ('users', 'rev_f2p_user_id', 'posts')
    edge_types = data.edge_types
    assert fwd_type in edge_types
    assert rev_type in edge_types

    # All 4 posts reference a user, so both directions hold 4 edges:
    fwd_index = data[fwd_type].edge_index
    rev_index = data[rev_type].edge_index
    assert fwd_index.size() == (2, 4)
    assert rev_index.size() == (2, 4)
| 86 | + |
| 87 | + |
def test_from_relbench_dangling_fkeys():
    """Check that rows with a missing foreign key produce no edge."""
    # The middle post has no user (a missing value in a nullable int column):
    post_user_ids = pd.array([0, None, 1], dtype=pd.Int64Dtype())

    user_table = _mock_table(
        df=pd.DataFrame({'id': [0, 1]}),
        fkey_col_to_pkey_table={},
        pkey_col='id',
    )
    post_table = _mock_table(
        df=pd.DataFrame({'id': [0, 1, 2], 'user_id': post_user_ids}),
        fkey_col_to_pkey_table={'user_id': 'users'},
        pkey_col='id',
    )
    tables = {'users': user_table, 'posts': post_table}
    data = from_relbench(_mock_database(table_dict=tables))

    # Only 2 of the 3 posts carry a usable foreign key:
    fwd_index = data['posts', 'f2p_user_id', 'users'].edge_index
    assert fwd_index.size() == (2, 2)
| 114 | + |
| 115 | + |
def test_from_relbench_time_column():
    """Check that the time column is exposed as `time` and kept out of `x`."""
    timestamps = pd.to_datetime(['2024-01-01', '2024-01-02', '2024-01-03'])
    event_table = _mock_table(
        df=pd.DataFrame({
            'id': [0, 1, 2],
            'ts': timestamps,
            'val': [10, 20, 30],
        }),
        fkey_col_to_pkey_table={},
        pkey_col='id',
        time_col='ts',
    )
    data = from_relbench(_mock_database(table_dict={'events': event_table}))

    event_store = data['events']
    assert event_store.num_nodes == 3
    assert event_store.time is not None
    assert event_store.time.size() == (3, )
    # The timestamp must not be counted as a feature, leaving only 'val':
    assert event_store.x.size() == (3, 1)
| 140 | + |
| 141 | + |
def test_from_relbench_no_features():
    """Check a table whose only non-key column is non-numeric."""
    item_table = _mock_table(
        df=pd.DataFrame({
            'id': [0, 1, 2],
            'name': ['a', 'b', 'c'],  # String column, must not become a feature
        }),
        fkey_col_to_pkey_table={},
        pkey_col='id',
    )
    data = from_relbench(_mock_database(table_dict={'items': item_table}))

    item_store = data['items']
    assert item_store.num_nodes == 3
    # `id` is the primary key and `name` is a string, so `x` is either
    # absent altogether or explicitly `None`:
    assert getattr(item_store, 'x', None) is None
| 161 | + |
| 162 | + |
@withPackage('relbench')
def test_from_relbench_with_relbench():
    """Run the conversion on real relbench `Table`/`Database` objects."""
    from relbench.base import Database, Table

    user_table = Table(
        df=pd.DataFrame({
            'id': [0, 1, 2],
            'age': [25, 30, 35],
        }),
        fkey_col_to_pkey_table={},
        pkey_col='id',
    )
    post_table = Table(
        df=pd.DataFrame({
            'id': [0, 1, 2],
            'user_id': [0, 1, 0],
            'score': [10, 20, 30],
        }),
        fkey_col_to_pkey_table={'user_id': 'users'},
        pkey_col='id',
    )
    database = Database(table_dict={'users': user_table, 'posts': post_table})
    data = from_relbench(database)

    # Both tables become node types:
    node_types = data.node_types
    assert 'users' in node_types
    assert 'posts' in node_types
    # One node per row in each table:
    assert data['users'].num_nodes == 3
    assert data['posts'].num_nodes == 3
0 commit comments