|
1 | 1 | import sqlite3 |
2 | | -from helpers import exec, vec0_shadow_table_contents |
| 2 | +import struct |
| 3 | +import pytest |
| 4 | +from helpers import exec, vec0_shadow_table_contents, _f32 |
3 | 5 |
|
4 | 6 |
|
5 | 7 | def test_constructor_limit(db, snapshot): |
@@ -126,3 +128,198 @@ def test_knn(db, snapshot): |
126 | 128 | ) == snapshot(name="illegal KNN w/ aux") |
127 | 129 |
|
128 | 130 |
|
| 131 | +# ====================================================================== |
| 132 | +# Auxiliary columns with non-flat indexes |
| 133 | +# ====================================================================== |
| 134 | + |
| 135 | + |
def test_rescore_aux_shadow_tables(db, snapshot):
    """Snapshot the shadow-table schema of a rescore table with aux columns.

    Creates ``t`` with one rescore-indexed vector column plus two auxiliary
    columns (``label`` and ``score``), then compares the ``sqlite_master``
    rows selected by the ``name LIKE 't_%'`` filter against the snapshot
    named "rescore aux shadow tables".
    """
    create_stmt = (
        "CREATE VIRTUAL TABLE t USING vec0("
        " emb float[128] indexed by rescore(quantizer=bit),"
        " +label text,"
        " +score float"
        ")"
    )
    db.execute(create_stmt)

    shadow_schema = exec(
        db,
        "SELECT name, sql FROM sqlite_master WHERE type='table' AND name LIKE 't_%' ORDER BY name",
    )
    assert shadow_schema == snapshot(name="rescore aux shadow tables")
| 148 | + |
| 149 | + |
def test_rescore_aux_insert_knn(db, snapshot):
    """Insert rows with aux data; KNN should return the aux column values.

    Three 128-dimensional Gaussian vectors labelled "alpha", "beta" and
    "gamma" are inserted into a rescore-indexed table. The test snapshots
    the visible rows and the shadow-table contents, then runs a KNN query
    with the "alpha" vector and expects "alpha" as the first result.
    """
    import random

    db.execute(
        "CREATE VIRTUAL TABLE t USING vec0("
        " emb float[128] indexed by rescore(quantizer=bit),"
        " +label text"
        ")"
    )

    # Use a private generator instead of random.seed(): it produces the same
    # sequence for the same seed but does not reseed the module-level RNG
    # that other tests in the same process may rely on.
    rng = random.Random(77)
    # Labels are generated in this order so each vector keeps its position
    # in the random stream.
    data = [
        (label, [rng.gauss(0, 1) for _ in range(128)])
        for label in ("alpha", "beta", "gamma")
    ]
    for label, vec in data:
        db.execute(
            "INSERT INTO t(emb, label) VALUES (?, ?)",
            [_f32(vec), label],
        )

    assert exec(db, "SELECT rowid, label FROM t ORDER BY rowid") == snapshot(
        name="rescore aux select all"
    )
    assert vec0_shadow_table_contents(db, "t", skip_info=True) == snapshot(
        name="rescore aux shadow contents"
    )

    # KNN should include the aux column; "alpha" is closest to its own vector.
    alpha_vec = data[0][1]
    rows = db.execute(
        "SELECT label, distance FROM t WHERE emb MATCH ? ORDER BY distance LIMIT 3",
        [_f32(alpha_vec)],
    ).fetchall()
    assert len(rows) == 3
    assert rows[0][0] == "alpha"
| 185 | + |
| 186 | + |
def test_rescore_aux_update(db):
    """UPDATE of an aux column on a rescore table should work and keep KNN usable.

    Inserts one row labelled 'original', updates the label to 'updated',
    and checks the new value both through a rowid lookup and through a KNN
    query using the row's own vector.
    """
    import random

    db.execute(
        "CREATE VIRTUAL TABLE t USING vec0("
        " emb float[128] indexed by rescore(quantizer=bit),"
        " +label text"
        ")"
    )

    # A local generator gives the same values as random.seed(88) followed by
    # module-level calls, without altering the shared module-level RNG state.
    rng = random.Random(88)
    vec = [rng.gauss(0, 1) for _ in range(128)]

    db.execute("INSERT INTO t(rowid, emb, label) VALUES (1, ?, 'original')", [_f32(vec)])
    db.execute("UPDATE t SET label = 'updated' WHERE rowid = 1")

    assert db.execute("SELECT label FROM t WHERE rowid = 1").fetchone()[0] == "updated"

    # KNN still works and reports the updated aux value.
    rows = db.execute(
        "SELECT rowid, label FROM t WHERE emb MATCH ? ORDER BY distance LIMIT 1",
        [_f32(vec)],
    ).fetchall()
    assert rows[0][0] == 1
    assert rows[0][1] == "updated"
| 210 | + |
| 211 | + |
def test_rescore_aux_delete(db, snapshot):
    """DELETE should remove the row's aux data from the shadow table.

    Inserts rowids 1..5 labelled "item-1".."item-5", deletes rowid 3, and
    snapshots both the virtual table rows and the ``t_auxiliary`` rows.
    """
    import random

    db.execute(
        "CREATE VIRTUAL TABLE t USING vec0("
        " emb float[128] indexed by rescore(quantizer=bit),"
        " +label text"
        ")"
    )

    # A local generator gives the same values as random.seed(99) followed by
    # module-level calls, without altering the shared module-level RNG state.
    rng = random.Random(99)
    for rowid in range(1, 6):
        vec = [rng.gauss(0, 1) for _ in range(128)]
        db.execute(
            "INSERT INTO t(rowid, emb, label) VALUES (?, ?, ?)",
            [rowid, _f32(vec), f"item-{rowid}"],
        )

    db.execute("DELETE FROM t WHERE rowid = 3")

    assert exec(db, "SELECT rowid, label FROM t ORDER BY rowid") == snapshot(
        name="rescore aux after delete"
    )
    assert exec(db, "SELECT rowid, value00 FROM t_auxiliary ORDER BY rowid") == snapshot(
        name="rescore aux shadow after delete"
    )
| 236 | + |
| 237 | + |
def test_diskann_aux_shadow_tables(db, snapshot):
    """Snapshot the shadow-table schema of a DiskANN table with aux columns.

    Creates ``t`` with a DiskANN-indexed 8-dimensional vector column and two
    auxiliary columns, then compares the ``sqlite_master`` rows selected by
    the ``name LIKE 't_%'`` filter against the stored snapshot.
    """
    create_stmt = """
        CREATE VIRTUAL TABLE t USING vec0(
            emb float[8] INDEXED BY diskann(neighbor_quantizer=binary, n_neighbors=8),
            +label text,
            +score float
        )
    """
    db.execute(create_stmt)

    shadow_schema = exec(
        db,
        "SELECT name, sql FROM sqlite_master WHERE type='table' AND name LIKE 't_%' ORDER BY name",
    )
    assert shadow_schema == snapshot(name="diskann aux shadow tables")
| 250 | + |
| 251 | + |
def test_diskann_aux_insert_knn(db, snapshot):
    """Insert into a DiskANN table with an aux column and query it via KNN.

    Three one-hot vectors labelled "red", "green" and "blue" are inserted.
    The visible rows and shadow-table contents are snapshotted, then a KNN
    query (``k = 3``) with the "red" vector must return at least one row,
    the first of which is labelled "red".
    """
    db.execute("""
        CREATE VIRTUAL TABLE t USING vec0(
            emb float[8] INDEXED BY diskann(neighbor_quantizer=binary, n_neighbors=8),
            +label text
        )
    """)

    labelled_vectors = [
        ("red", [1, 0, 0, 0, 0, 0, 0, 0]),
        ("green", [0, 1, 0, 0, 0, 0, 0, 0]),
        ("blue", [0, 0, 1, 0, 0, 0, 0, 0]),
    ]
    for name, components in labelled_vectors:
        db.execute("INSERT INTO t(emb, label) VALUES (?, ?)", [_f32(components), name])

    all_rows = exec(db, "SELECT rowid, label FROM t ORDER BY rowid")
    assert all_rows == snapshot(name="diskann aux select all")

    shadow_contents = vec0_shadow_table_contents(db, "t", skip_info=True)
    assert shadow_contents == snapshot(name="diskann aux shadow contents")

    # Query with the same components as the "red" row.
    query_vec = [1, 0, 0, 0, 0, 0, 0, 0]
    cursor = db.execute(
        "SELECT label, distance FROM t WHERE emb MATCH ? AND k = 3",
        [_f32(query_vec)],
    )
    rows = cursor.fetchall()
    assert len(rows) >= 1
    assert rows[0][0] == "red"
| 281 | + |
| 282 | + |
def test_diskann_aux_update_and_delete(db, snapshot):
    """Update an aux value and delete a row on a DiskANN table, then snapshot.

    Rowids 1..5 are inserted with one-hot vectors and labels "item-1" to
    "item-5". Row 2's label is set to 'UPDATED' and row 3 is deleted; both
    the virtual table rows and the ``t_auxiliary`` rows are snapshotted.
    """
    db.execute("""
        CREATE VIRTUAL TABLE t USING vec0(
            emb float[8] INDEXED BY diskann(neighbor_quantizer=binary, n_neighbors=8),
            +label text
        )
    """)

    for rowid in range(1, 6):
        # One-hot vector: component (rowid - 1) is 1.0, every other one 0.0.
        hot = (rowid - 1) % 8
        one_hot = [1.0 if pos == hot else 0.0 for pos in range(8)]
        db.execute(
            "INSERT INTO t(rowid, emb, label) VALUES (?, ?, ?)",
            [rowid, _f32(one_hot), f"item-{rowid}"],
        )

    db.execute("UPDATE t SET label = 'UPDATED' WHERE rowid = 2")
    db.execute("DELETE FROM t WHERE rowid = 3")

    visible_rows = exec(db, "SELECT rowid, label FROM t ORDER BY rowid")
    assert visible_rows == snapshot(name="diskann aux after update+delete")

    aux_rows = exec(db, "SELECT rowid, value00 FROM t_auxiliary ORDER BY rowid")
    assert aux_rows == snapshot(name="diskann aux shadow after update+delete")
| 308 | + |
| 309 | + |
def test_diskann_aux_drop_cleans_all(db):
    """DROP TABLE on a DiskANN table should also remove its aux shadow table.

    NOTE(review): despite the name, only ``t_auxiliary`` is asserted to be
    gone; other names returned by the query are not checked.
    """
    db.execute("""
        CREATE VIRTUAL TABLE t USING vec0(
            emb float[8] INDEXED BY diskann(neighbor_quantizer=binary, n_neighbors=8),
            +label text
        )
    """)
    db.execute("INSERT INTO t(emb, label) VALUES (?, 'test')", [_f32([1] * 8)])
    db.execute("DROP TABLE t")

    # Collect whatever names still match the filter after the drop.
    leftover = db.execute(
        "SELECT name FROM sqlite_master WHERE name LIKE 't_%'"
    ).fetchall()
    tables = [row[0] for row in leftover]
    assert "t_auxiliary" not in tables
| 325 | + |
0 commit comments