|
12 | 12 | - Memory-efficient streaming for large files |
13 | 13 | - Smart sampling for fast verification (100K rows) |
14 | 14 | - Stack Overflow vector conversion (requires sentence-transformers + torch) |
| 15 | +- Stack Overflow unified vector ground-truth generation (MSMARCO-style) |
15 | 16 |
|
16 | 17 | Available datasets: |
17 | 18 | 1. MovieLens (movie ratings, tags, genres): |
@@ -1014,6 +1015,9 @@ def embed_stackoverflow_vectors( |
1014 | 1015 | shard_size: int = 100_000, |
1015 | 1016 | max_rows: int | None = None, |
1016 | 1017 | progress_every: int = 10_000, |
| 1018 | + gt_queries: int = 1000, |
| 1019 | + gt_topk: int = 50, |
| 1020 | + gt_chunk: int = 4096, |
1017 | 1021 | ) -> None: |
1018 | 1022 | try: |
1019 | 1023 | import numpy as np |
@@ -1204,6 +1208,217 @@ def answer_text(row: dict[str, str | None]) -> str: |
1204 | 1208 | ["Text"], |
1205 | 1209 | ) |
1206 | 1210 |
|
| 1211 | + corpus_names = ["questions", "answers", "comments"] |
| 1212 | + corpus_metas: list[dict[str, object]] = [] |
| 1213 | + for name in corpus_names: |
| 1214 | + meta_path = out_dir / f"{dataset_name}-{name}.meta.json" |
| 1215 | + if not meta_path.exists(): |
| 1216 | + raise FileNotFoundError(f"Missing vector metadata: {meta_path}") |
| 1217 | + corpus_metas.append(json.loads(meta_path.read_text(encoding="utf-8"))) |
| 1218 | + |
| 1219 | + dims = { |
| 1220 | + int(meta["dim"]) |
| 1221 | + for meta in corpus_metas |
| 1222 | + if meta.get("dim") is not None and int(meta.get("count", 0)) > 0 |
| 1223 | + } |
| 1224 | + if len(dims) > 1: |
| 1225 | + raise RuntimeError( |
| 1226 | + f"Mismatched vector dimensions across corpora: {sorted(dims)}" |
| 1227 | + ) |
| 1228 | + combined_dim = next(iter(dims)) if dims else None |
| 1229 | + |
| 1230 | + combined_shards: list[dict[str, object]] = [] |
| 1231 | + combined_total = 0 |
| 1232 | + for meta in corpus_metas: |
| 1233 | + source_corpus = str(meta["corpus"]) |
| 1234 | + for shard in meta.get("shards", []): |
| 1235 | + shard_obj = dict(shard) |
| 1236 | + shard_path = out_dir / str(shard_obj["path"]) |
| 1237 | + shard_count = int(shard_obj["count"]) |
| 1238 | + combined_shards.append( |
| 1239 | + { |
| 1240 | + "path": str(shard_obj["path"]), |
| 1241 | + "path_obj": shard_path, |
| 1242 | + "count": shard_count, |
| 1243 | + "start": combined_total, |
| 1244 | + "source_corpus": source_corpus, |
| 1245 | + } |
| 1246 | + ) |
| 1247 | + combined_total += shard_count |
| 1248 | + |
| 1249 | + combined_ids_path = out_dir / f"{dataset_name}-all.ids.jsonl" |
| 1250 | + global_id = 0 |
| 1251 | + with open(combined_ids_path, "w", encoding="utf-8") as fout: |
| 1252 | + for name in corpus_names: |
| 1253 | + source_ids = out_dir / f"{dataset_name}-{name}.ids.jsonl" |
| 1254 | + with open(source_ids, "r", encoding="utf-8") as fin: |
| 1255 | + for line in fin: |
| 1256 | + if not line.strip(): |
| 1257 | + continue |
| 1258 | + obj = json.loads(line) |
| 1259 | + obj["source_corpus"] = name |
| 1260 | + obj["source_vector_id"] = obj.get("vector_id") |
| 1261 | + obj["vector_id"] = global_id |
| 1262 | + fout.write(json.dumps(obj) + "\n") |
| 1263 | + global_id += 1 |
| 1264 | + |
| 1265 | + if global_id != combined_total: |
| 1266 | + raise RuntimeError( |
| 1267 | + "Combined id count does not match combined vector count: " |
| 1268 | + f"ids={global_id}, vectors={combined_total}" |
| 1269 | + ) |
| 1270 | + |
| 1271 | + def build_gt_sharded( |
| 1272 | + *, |
| 1273 | + shards: list[dict[str, object]], |
| 1274 | + total_count: int, |
| 1275 | + dim: int, |
| 1276 | + gt_path: Path, |
| 1277 | + q_count: int, |
| 1278 | + topk: int, |
| 1279 | + ) -> None: |
| 1280 | + import heapq |
| 1281 | + import mmap |
| 1282 | + |
| 1283 | + print(f"[GT] building exact GT for {q_count} queries, k={topk}") |
| 1284 | + |
| 1285 | + q_count = min(q_count, total_count) |
| 1286 | + rng = np.random.default_rng() |
| 1287 | + q_indices = rng.choice(total_count, size=q_count, replace=False) |
| 1288 | + |
| 1289 | + queries = np.empty((q_count, dim), dtype=np.float32) |
| 1290 | + shard_map: dict[Path, list[tuple[int, int]]] = {} |
| 1291 | + |
| 1292 | + for qi, gidx in enumerate(q_indices): |
| 1293 | + for shard in shards: |
| 1294 | + shard_start = int(shard["start"]) |
| 1295 | + shard_count = int(shard["count"]) |
| 1296 | + if shard_start <= gidx < shard_start + shard_count: |
| 1297 | + shard_map.setdefault(Path(shard["path_obj"]), []).append( |
| 1298 | + (qi, int(gidx - shard_start)) |
| 1299 | + ) |
| 1300 | + break |
| 1301 | + |
| 1302 | + def close_memmap(mm: "np.memmap | None") -> None: |
| 1303 | + if mm is None: |
| 1304 | + return |
| 1305 | + mm.flush() |
| 1306 | + m = getattr(mm, "_mmap", None) |
| 1307 | + if m is not None: |
| 1308 | + try: |
| 1309 | + m.madvise(mmap.MADV_DONTNEED) |
| 1310 | + except Exception: |
| 1311 | + pass |
| 1312 | + m.close() |
| 1313 | + |
| 1314 | + for shard in shards: |
| 1315 | + shard_path = Path(shard["path_obj"]) |
| 1316 | + assigns = shard_map.get(shard_path) |
| 1317 | + if not assigns: |
| 1318 | + continue |
| 1319 | + mm = np.memmap( |
| 1320 | + shard_path, |
| 1321 | + dtype=np.float32, |
| 1322 | + mode="r", |
| 1323 | + shape=(int(shard["count"]), dim), |
| 1324 | + ) |
| 1325 | + for qi, local_idx in assigns: |
| 1326 | + queries[qi] = mm[local_idx] |
| 1327 | + close_memmap(mm) |
| 1328 | + |
| 1329 | + heaps = [[] for _ in range(q_count)] |
| 1330 | + |
| 1331 | + for shard in shards: |
| 1332 | + shard_path = Path(shard["path_obj"]) |
| 1333 | + print(f"[GT] scanning {shard_path.name}") |
| 1334 | + mm = np.memmap( |
| 1335 | + shard_path, |
| 1336 | + dtype=np.float32, |
| 1337 | + mode="r", |
| 1338 | + shape=(int(shard["count"]), dim), |
| 1339 | + ) |
| 1340 | + for off in range(0, int(shard["count"]), gt_chunk): |
| 1341 | + block = mm[off : off + gt_chunk] |
| 1342 | + sims = block @ queries.T |
| 1343 | + for qi in range(q_count): |
| 1344 | + heap = heaps[qi] |
| 1345 | + col = sims[:, qi] |
| 1346 | + for i, score in enumerate(col): |
| 1347 | + doc_id = int(shard["start"]) + off + i |
| 1348 | + if len(heap) < topk: |
| 1349 | + heapq.heappush(heap, (float(score), doc_id)) |
| 1350 | + else: |
| 1351 | + heapq.heappushpop(heap, (float(score), doc_id)) |
| 1352 | + close_memmap(mm) |
| 1353 | + |
| 1354 | + with open(gt_path, "w", encoding="utf-8") as f: |
| 1355 | + for qi, heap in enumerate(heaps): |
| 1356 | + heap.sort(reverse=True) |
| 1357 | + json.dump( |
| 1358 | + { |
| 1359 | + "query_id": int(q_indices[qi]), |
| 1360 | + "topk": [ |
| 1361 | + {"doc_id": int(doc_id), "score": float(score)} |
| 1362 | + for score, doc_id in heap |
| 1363 | + ], |
| 1364 | + }, |
| 1365 | + f, |
| 1366 | + ) |
| 1367 | + f.write("\n") |
| 1368 | + |
| 1369 | + print(f"[GT] wrote {gt_path}") |
| 1370 | + |
| 1371 | + gt_path = out_dir / f"{dataset_name}-all.gt.jsonl" |
| 1372 | + if combined_total <= 0: |
| 1373 | + print("[GT] skipping GT generation (no vectors found)") |
| 1374 | + elif combined_dim is None: |
| 1375 | + print("[GT] skipping GT generation (missing vector dimensions)") |
| 1376 | + else: |
| 1377 | + build_gt_sharded( |
| 1378 | + shards=combined_shards, |
| 1379 | + total_count=combined_total, |
| 1380 | + dim=int(combined_dim), |
| 1381 | + gt_path=gt_path, |
| 1382 | + q_count=gt_queries, |
| 1383 | + topk=gt_topk, |
| 1384 | + ) |
| 1385 | + |
| 1386 | + combined_meta_path = out_dir / f"{dataset_name}-all.meta.json" |
| 1387 | + combined_meta = { |
| 1388 | + "dataset": dataset_name, |
| 1389 | + "corpus": "all", |
| 1390 | + "source_corpora": corpus_names, |
| 1391 | + "model": model_name, |
| 1392 | + "device": device, |
| 1393 | + "dim": combined_dim, |
| 1394 | + "dtype": "float32", |
| 1395 | + "count": combined_total, |
| 1396 | + "shard_size": shard_size, |
| 1397 | + "shards": [ |
| 1398 | + { |
| 1399 | + "path": str(shard["path"]), |
| 1400 | + "count": int(shard["count"]), |
| 1401 | + "start": int(shard["start"]), |
| 1402 | + "source_corpus": str(shard["source_corpus"]), |
| 1403 | + } |
| 1404 | + for shard in combined_shards |
| 1405 | + ], |
| 1406 | + "ids_file": combined_ids_path.name, |
| 1407 | + "gt_file": gt_path.name, |
| 1408 | + "gt_queries": min(gt_queries, combined_total), |
| 1409 | + "gt_topk": gt_topk, |
| 1410 | + "max_seq_length": max_seq_len, |
| 1411 | + } |
| 1412 | + combined_meta_path.write_text( |
| 1413 | + json.dumps(combined_meta, indent=2), |
| 1414 | + encoding="utf-8", |
| 1415 | + ) |
| 1416 | + print( |
| 1417 | + f"[VECTORS] all: {combined_total:,} vectors, " |
| 1418 | + f"{len(combined_shards)} shards, " |
| 1419 | + f"gt={gt_path.name}" |
| 1420 | + ) |
| 1421 | + |
1207 | 1422 |
|
1208 | 1423 | def download_tpch(scale_factor: int = 10) -> Path: |
1209 | 1424 | """Generate TPC-H data using dbgen via Docker.""" |
@@ -2038,6 +2253,18 @@ def main(): |
2038 | 2253 | default=None, |
2039 | 2254 | help="Optional max vectors per corpus (questions/answers/comments)", |
2040 | 2255 | ) |
| 2256 | + parser.add_argument( |
| 2257 | + "--vector-gt-queries", |
| 2258 | + type=int, |
| 2259 | + default=1000, |
| 2260 | + help="Number of sampled queries for Stack Overflow GT (default: 1000)", |
| 2261 | + ) |
| 2262 | + parser.add_argument( |
| 2263 | + "--vector-gt-topk", |
| 2264 | + type=int, |
| 2265 | + default=50, |
| 2266 | + help="Top-k neighbors per sampled query for Stack Overflow GT (default: 50)", |
| 2267 | + ) |
2041 | 2268 | args = parser.parse_args() |
2042 | 2269 |
|
2043 | 2270 | print("=" * 70) |
@@ -2220,6 +2447,8 @@ def main(): |
2220 | 2447 | batch_size=args.vector_batch_size, |
2221 | 2448 | shard_size=args.vector_shard_size, |
2222 | 2449 | max_rows=args.vector_max_rows, |
| 2450 | + gt_queries=args.vector_gt_queries, |
| 2451 | + gt_topk=args.vector_gt_topk, |
2223 | 2452 | ) |
2224 | 2453 | print() |
2225 | 2454 |
|
|
0 commit comments