Skip to content

Commit 9be023e

Browse files
committed
feat: add unified vector ground-truth generation for Stack Overflow datasets
1 parent 7d6fa45 commit 9be023e

2 files changed

Lines changed: 237 additions & 1 deletion

File tree

bindings/python/docs/examples/download_data.md

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ python download_data.py stackoverflow-small --vector-model all-MiniLM-L6-v2
3030
python download_data.py stackoverflow-small --vector-batch-size 128
3131
python download_data.py stackoverflow-small --vector-shard-size 100000
3232
python download_data.py stackoverflow-small --vector-max-rows 50000
33+
python download_data.py stackoverflow-small --vector-gt-queries 1000 --vector-gt-topk 50
3334
python download_data.py stackoverflow-large
3435
python download_data.py stackoverflow-full
3536
python download_data.py msmarco-1m
@@ -39,6 +40,8 @@ python download_data.py msmarco-1m
3940

4041
- **MovieLens NULL injection** is enabled by default (use `--no-nulls` to skip).
4142
- **Stack Exchange vectors** are generated by default for questions, answers, and comments.
43+
The script also builds a combined `all` corpus and computes exact ground truth
44+
(`.gt.jsonl`) from sampled queries, MSMARCO-style.
4245
Use `--no-vectors` to skip.
4346
- **MSMARCO** downloads parquet shards and converts them to vector shards with a ground-truth file.
4447

@@ -55,13 +58,17 @@ Install only what you need for the datasets you plan to download:
5558
- MovieLens: `examples/data/movielens-<size>/`
5659
- Stack Exchange: `examples/data/stackoverflow-<size>/`
5760
- Stack Exchange vectors: `examples/data/stackoverflow-<size>/vectors/`
61+
- Includes per-corpus files (`questions`, `answers`, `comments`) and combined `all` files
5862
- MSMARCO: `examples/data/MSMARCO-<size>/`
5963

6064
## Formats & Schemas
6165

6266
- **MovieLens**: CSV files, no schema file generated.
6367
- **Stack Exchange**: XML files, no schema file generated.
64-
- **Stack Exchange vectors**: binary vector shards (`.f32`) plus `.meta.json` and `.ids.jsonl`.
68+
- **Stack Exchange vectors**: binary vector shards (`.f32`) plus `.meta.json`, `.ids.jsonl`, and combined `.gt.jsonl`.
69+
- Per-corpus outputs: `stackoverflow-<size>-questions|answers|comments.{meta.json,ids.jsonl,shard*.f32}`
70+
- Combined outputs: `stackoverflow-<size>-all.meta.json`, `stackoverflow-<size>-all.ids.jsonl`, `stackoverflow-<size>-all.gt.jsonl`
71+
- Defaults for combined GT: `1000` sampled queries, `topk=50` (configurable via `--vector-gt-queries` and `--vector-gt-topk`)
6572
- Vectors are 384-D, L2-normalized (all-MiniLM-L6-v2).
6673
- **MSMARCO**: binary vector shards (`.f32`) plus `.meta.json` and `.gt.jsonl`.
6774
- Vectors are 1024‑D; 1M/5M/10M indicate the number of vectors.

bindings/python/examples/download_data.py

Lines changed: 229 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
- Memory-efficient streaming for large files
1313
- Smart sampling for fast verification (100K rows)
1414
- Stack Overflow vector conversion (requires sentence-transformers + torch)
15+
- Stack Overflow unified vector ground-truth generation (MSMARCO-style)
1516
1617
Available datasets:
1718
1. MovieLens (movie ratings, tags, genres):
@@ -1014,6 +1015,9 @@ def embed_stackoverflow_vectors(
10141015
shard_size: int = 100_000,
10151016
max_rows: int | None = None,
10161017
progress_every: int = 10_000,
1018+
gt_queries: int = 1000,
1019+
gt_topk: int = 50,
1020+
gt_chunk: int = 4096,
10171021
) -> None:
10181022
try:
10191023
import numpy as np
@@ -1204,6 +1208,217 @@ def answer_text(row: dict[str, str | None]) -> str:
12041208
["Text"],
12051209
)
12061210

1211+
corpus_names = ["questions", "answers", "comments"]
1212+
corpus_metas: list[dict[str, object]] = []
1213+
for name in corpus_names:
1214+
meta_path = out_dir / f"{dataset_name}-{name}.meta.json"
1215+
if not meta_path.exists():
1216+
raise FileNotFoundError(f"Missing vector metadata: {meta_path}")
1217+
corpus_metas.append(json.loads(meta_path.read_text(encoding="utf-8")))
1218+
1219+
dims = {
1220+
int(meta["dim"])
1221+
for meta in corpus_metas
1222+
if meta.get("dim") is not None and int(meta.get("count", 0)) > 0
1223+
}
1224+
if len(dims) > 1:
1225+
raise RuntimeError(
1226+
f"Mismatched vector dimensions across corpora: {sorted(dims)}"
1227+
)
1228+
combined_dim = next(iter(dims)) if dims else None
1229+
1230+
combined_shards: list[dict[str, object]] = []
1231+
combined_total = 0
1232+
for meta in corpus_metas:
1233+
source_corpus = str(meta["corpus"])
1234+
for shard in meta.get("shards", []):
1235+
shard_obj = dict(shard)
1236+
shard_path = out_dir / str(shard_obj["path"])
1237+
shard_count = int(shard_obj["count"])
1238+
combined_shards.append(
1239+
{
1240+
"path": str(shard_obj["path"]),
1241+
"path_obj": shard_path,
1242+
"count": shard_count,
1243+
"start": combined_total,
1244+
"source_corpus": source_corpus,
1245+
}
1246+
)
1247+
combined_total += shard_count
1248+
1249+
combined_ids_path = out_dir / f"{dataset_name}-all.ids.jsonl"
1250+
global_id = 0
1251+
with open(combined_ids_path, "w", encoding="utf-8") as fout:
1252+
for name in corpus_names:
1253+
source_ids = out_dir / f"{dataset_name}-{name}.ids.jsonl"
1254+
with open(source_ids, "r", encoding="utf-8") as fin:
1255+
for line in fin:
1256+
if not line.strip():
1257+
continue
1258+
obj = json.loads(line)
1259+
obj["source_corpus"] = name
1260+
obj["source_vector_id"] = obj.get("vector_id")
1261+
obj["vector_id"] = global_id
1262+
fout.write(json.dumps(obj) + "\n")
1263+
global_id += 1
1264+
1265+
if global_id != combined_total:
1266+
raise RuntimeError(
1267+
"Combined id count does not match combined vector count: "
1268+
f"ids={global_id}, vectors={combined_total}"
1269+
)
1270+
1271+
def build_gt_sharded(
1272+
*,
1273+
shards: list[dict[str, object]],
1274+
total_count: int,
1275+
dim: int,
1276+
gt_path: Path,
1277+
q_count: int,
1278+
topk: int,
1279+
) -> None:
1280+
import heapq
1281+
import mmap
1282+
1283+
print(f"[GT] building exact GT for {q_count} queries, k={topk}")
1284+
1285+
q_count = min(q_count, total_count)
1286+
rng = np.random.default_rng()
1287+
q_indices = rng.choice(total_count, size=q_count, replace=False)
1288+
1289+
queries = np.empty((q_count, dim), dtype=np.float32)
1290+
shard_map: dict[Path, list[tuple[int, int]]] = {}
1291+
1292+
for qi, gidx in enumerate(q_indices):
1293+
for shard in shards:
1294+
shard_start = int(shard["start"])
1295+
shard_count = int(shard["count"])
1296+
if shard_start <= gidx < shard_start + shard_count:
1297+
shard_map.setdefault(Path(shard["path_obj"]), []).append(
1298+
(qi, int(gidx - shard_start))
1299+
)
1300+
break
1301+
1302+
def close_memmap(mm: "np.memmap | None") -> None:
1303+
if mm is None:
1304+
return
1305+
mm.flush()
1306+
m = getattr(mm, "_mmap", None)
1307+
if m is not None:
1308+
try:
1309+
m.madvise(mmap.MADV_DONTNEED)
1310+
except Exception:
1311+
pass
1312+
m.close()
1313+
1314+
for shard in shards:
1315+
shard_path = Path(shard["path_obj"])
1316+
assigns = shard_map.get(shard_path)
1317+
if not assigns:
1318+
continue
1319+
mm = np.memmap(
1320+
shard_path,
1321+
dtype=np.float32,
1322+
mode="r",
1323+
shape=(int(shard["count"]), dim),
1324+
)
1325+
for qi, local_idx in assigns:
1326+
queries[qi] = mm[local_idx]
1327+
close_memmap(mm)
1328+
1329+
heaps = [[] for _ in range(q_count)]
1330+
1331+
for shard in shards:
1332+
shard_path = Path(shard["path_obj"])
1333+
print(f"[GT] scanning {shard_path.name}")
1334+
mm = np.memmap(
1335+
shard_path,
1336+
dtype=np.float32,
1337+
mode="r",
1338+
shape=(int(shard["count"]), dim),
1339+
)
1340+
for off in range(0, int(shard["count"]), gt_chunk):
1341+
block = mm[off : off + gt_chunk]
1342+
sims = block @ queries.T
1343+
for qi in range(q_count):
1344+
heap = heaps[qi]
1345+
col = sims[:, qi]
1346+
for i, score in enumerate(col):
1347+
doc_id = int(shard["start"]) + off + i
1348+
if len(heap) < topk:
1349+
heapq.heappush(heap, (float(score), doc_id))
1350+
else:
1351+
heapq.heappushpop(heap, (float(score), doc_id))
1352+
close_memmap(mm)
1353+
1354+
with open(gt_path, "w", encoding="utf-8") as f:
1355+
for qi, heap in enumerate(heaps):
1356+
heap.sort(reverse=True)
1357+
json.dump(
1358+
{
1359+
"query_id": int(q_indices[qi]),
1360+
"topk": [
1361+
{"doc_id": int(doc_id), "score": float(score)}
1362+
for score, doc_id in heap
1363+
],
1364+
},
1365+
f,
1366+
)
1367+
f.write("\n")
1368+
1369+
print(f"[GT] wrote {gt_path}")
1370+
1371+
gt_path = out_dir / f"{dataset_name}-all.gt.jsonl"
1372+
if combined_total <= 0:
1373+
print("[GT] skipping GT generation (no vectors found)")
1374+
elif combined_dim is None:
1375+
print("[GT] skipping GT generation (missing vector dimensions)")
1376+
else:
1377+
build_gt_sharded(
1378+
shards=combined_shards,
1379+
total_count=combined_total,
1380+
dim=int(combined_dim),
1381+
gt_path=gt_path,
1382+
q_count=gt_queries,
1383+
topk=gt_topk,
1384+
)
1385+
1386+
combined_meta_path = out_dir / f"{dataset_name}-all.meta.json"
1387+
combined_meta = {
1388+
"dataset": dataset_name,
1389+
"corpus": "all",
1390+
"source_corpora": corpus_names,
1391+
"model": model_name,
1392+
"device": device,
1393+
"dim": combined_dim,
1394+
"dtype": "float32",
1395+
"count": combined_total,
1396+
"shard_size": shard_size,
1397+
"shards": [
1398+
{
1399+
"path": str(shard["path"]),
1400+
"count": int(shard["count"]),
1401+
"start": int(shard["start"]),
1402+
"source_corpus": str(shard["source_corpus"]),
1403+
}
1404+
for shard in combined_shards
1405+
],
1406+
"ids_file": combined_ids_path.name,
1407+
"gt_file": gt_path.name,
1408+
"gt_queries": min(gt_queries, combined_total),
1409+
"gt_topk": gt_topk,
1410+
"max_seq_length": max_seq_len,
1411+
}
1412+
combined_meta_path.write_text(
1413+
json.dumps(combined_meta, indent=2),
1414+
encoding="utf-8",
1415+
)
1416+
print(
1417+
f"[VECTORS] all: {combined_total:,} vectors, "
1418+
f"{len(combined_shards)} shards, "
1419+
f"gt={gt_path.name}"
1420+
)
1421+
12071422

12081423
def download_tpch(scale_factor: int = 10) -> Path:
12091424
"""Generate TPC-H data using dbgen via Docker."""
@@ -2038,6 +2253,18 @@ def main():
20382253
default=None,
20392254
help="Optional max vectors per corpus (questions/answers/comments)",
20402255
)
2256+
parser.add_argument(
2257+
"--vector-gt-queries",
2258+
type=int,
2259+
default=1000,
2260+
help="Number of sampled queries for Stack Overflow GT (default: 1000)",
2261+
)
2262+
parser.add_argument(
2263+
"--vector-gt-topk",
2264+
type=int,
2265+
default=50,
2266+
help="Top-k neighbors per sampled query for Stack Overflow GT (default: 50)",
2267+
)
20412268
args = parser.parse_args()
20422269

20432270
print("=" * 70)
@@ -2220,6 +2447,8 @@ def main():
22202447
batch_size=args.vector_batch_size,
22212448
shard_size=args.vector_shard_size,
22222449
max_rows=args.vector_max_rows,
2450+
gt_queries=args.vector_gt_queries,
2451+
gt_topk=args.vector_gt_topk,
22232452
)
22242453
print()
22252454

0 commit comments

Comments
 (0)