|
| 1 | +#!/usr/bin/env python3 |
| 2 | +import argparse |
| 3 | +import os |
| 4 | +import signal |
| 5 | +import subprocess |
| 6 | +import time |
| 7 | +from pathlib import Path |
| 8 | + |
| 9 | + |
# Directory two levels above this file; every other path is derived from it.
ROOT = Path(__file__).resolve().parents[1]
# Splash checkout, looked up as a sibling directory of ROOT.
SPLASH = (ROOT / "../Splash").resolve()

# FAISS source tree and the CMake build directory used for the CPU build.
FAISS_SRC = ROOT.joinpath("workloads/faiss")
FAISS_BUILD = ROOT.joinpath("build/faiss_cpu")

# Benchmark driver source and the binary compiled from it.
BENCH_SRC = ROOT.joinpath("tools/faiss_splash/faiss_cpu_splash.cpp")
BENCH_BIN = FAISS_BUILD.joinpath("faiss_cpu_splash")

# Parent directory for the timestamped per-run output directories.
ARTIFACT = ROOT.joinpath("artifact/faiss_splash")
| 17 | + |
| 18 | + |
def run(cmd, **kwargs):
    """Echo a command to stdout, then execute it and fail loudly on error.

    Args:
        cmd: Sequence of command-line arguments; each item is stringified
            for the echoed line.
        **kwargs: Passed through unchanged to ``subprocess.run``.

    Raises:
        subprocess.CalledProcessError: If the command exits non-zero.
    """
    rendered = " ".join(map(str, cmd))
    # Flush so the echoed line appears before any output from the child.
    print("+", rendered, flush=True)
    subprocess.run(cmd, check=True, **kwargs)
| 22 | + |
| 23 | + |
def configure_and_build_faiss(jobs):
    """Configure FAISS with CMake if needed, then build the ``faiss`` target.

    The configure step runs only when ``CMakeCache.txt`` is absent from the
    build directory; the build step always runs.

    Args:
        jobs: Parallelism passed to ``cmake --build`` via ``-j``.

    Raises:
        subprocess.CalledProcessError: If a cmake invocation fails.
    """
    FAISS_BUILD.mkdir(parents=True, exist_ok=True)

    already_configured = (FAISS_BUILD / "CMakeCache.txt").exists()
    if not already_configured:
        cmake_options = [
            "-DCMAKE_BUILD_TYPE=Release",
            "-DFAISS_ENABLE_GPU=OFF",
            "-DFAISS_ENABLE_PYTHON=OFF",
            "-DFAISS_ENABLE_TESTS=OFF",
            "-DBUILD_TESTING=OFF",
            "-DFAISS_OPT_LEVEL=generic",
            "-DBLA_VENDOR=OpenBLAS",
        ]
        run(["cmake", "-S", str(FAISS_SRC), "-B", str(FAISS_BUILD), *cmake_options])

    build_cmd = ["cmake", "--build", str(FAISS_BUILD), "--target", "faiss", "-j", str(jobs)]
    run(build_cmd)
| 39 | + |
| 40 | + |
def find_one(patterns):
    """Return the first file under the FAISS build directory matching a pattern.

    Patterns are tried in the order given; for the first pattern that matches
    anything, the smallest matching path in sort order is returned.

    Args:
        patterns: Glob patterns, relative to the FAISS build directory.

    Raises:
        FileNotFoundError: If no pattern matches any file.
    """
    for pattern in patterns:
        hits = list(FAISS_BUILD.glob(pattern))
        if hits:
            # min() picks the same element sorted(...)[0] would.
            return min(hits)
    raise FileNotFoundError(f"no file matched {patterns}")
| 47 | + |
| 48 | + |
def build_benchmark():
    """Compile the benchmark driver against FAISS and the Splash backend.

    Links the driver source with the FAISS library found in the build tree
    (static preferred over shared, the conventional location preferred over a
    recursive search) and with Splash's ``libcxl_backend.a``.

    Raises:
        FileNotFoundError: If no FAISS library is found or the Splash backend
            archive is missing.
        subprocess.CalledProcessError: If the compiler invocation fails.
    """
    libfaiss = find_one(["faiss/libfaiss.a", "faiss/libfaiss.so", "**/libfaiss.a", "**/libfaiss.so"])

    splash_backend = SPLASH / "build/libcxl_backend.a"
    if not splash_backend.exists():
        raise FileNotFoundError(f"missing {splash_backend}; build Splash first")

    include_dirs = (FAISS_SRC, FAISS_BUILD, SPLASH / "src/libpgas/include")
    include_flags = [arg for directory in include_dirs for arg in ("-I", str(directory))]

    compile_cmd = [
        "g++", "-std=c++17", "-O3", "-fopenmp",
        *include_flags,
        str(BENCH_SRC),
        str(libfaiss),
        str(splash_backend),
        "-lopenblas", "-lpthread", "-lrt", "-ldl",
        "-o", str(BENCH_BIN),
    ]
    run(compile_cmd)
| 65 | + |
| 66 | + |
def start_pool(args, log_path):
    """Launch the shared-memory pool server in the background.

    Depending on ``args.pool_provider`` this starts either Splash's
    ``cxl_shmem_server`` (size passed in bytes) or ``cxlmemsim_server`` in
    ``pgas-shm`` mode (capacity passed as the raw ``capacity_mb`` value). The
    cxlmemsim server is built on demand if its binary is missing.

    Args:
        args: Parsed command-line namespace; reads ``pool_provider``,
            ``pool_name``, ``capacity_mb`` and ``latency_ns``.
        log_path: File that receives the server's stdout and stderr.

    Returns:
        A ``(proc, log)`` tuple: the running ``subprocess.Popen`` and the open
        log file. The caller owns both and must stop/close them.

    Raises:
        FileNotFoundError: If the Splash server binary is missing.
        subprocess.CalledProcessError: If the on-demand cxlmemsim build fails.
        RuntimeError: If the server has already exited after the startup
            grace period.
    """
    log = open(log_path, "w", encoding="utf-8")
    proc = None
    try:
        if args.pool_provider == "splash":
            server = SPLASH / "build/cxl_shmem_server"
            if not server.exists():
                raise FileNotFoundError(f"missing {server}; build Splash first")
            cmd = [
                str(server),
                "--name", args.pool_name,
                "--size", str(args.capacity_mb * 1024 * 1024),
                "--latency", str(args.latency_ns),
            ]
            cwd = SPLASH
        else:
            server = ROOT / "build/cxlmemsim_server"
            if not server.exists():
                run(["cmake", "--build", str(ROOT / "build"), "--target", "cxlmemsim_server", "-j", "4"])
            cmd = [
                str(server),
                "--comm-mode", "pgas-shm",
                "--pgas-shm-name", args.pool_name,
                "--capacity", str(args.capacity_mb),
                "--default_latency", str(args.latency_ns),
            ]
            cwd = ROOT

        proc = subprocess.Popen(cmd, stdout=log, stderr=subprocess.STDOUT, cwd=cwd)
        # Fixed grace period; a server that is already gone by the end of it
        # is treated as a startup failure.
        time.sleep(2.0)
        if proc.poll() is not None:
            # Name the server that was actually launched, not always cxlmemsim.
            raise RuntimeError(
                f"{server.name} exited early (exit code {proc.returncode}), see {log_path}"
            )
    except BaseException:
        # The caller never receives proc/log on failure (including Ctrl-C
        # during the grace period), so release both here.
        if proc is not None and proc.poll() is None:
            proc.kill()
            proc.wait()
        log.close()
        raise
    return proc, log
| 99 | + |
| 100 | + |
def stop_proc(proc, log, timeout=5):
    """Stop a background process and close its log file.

    Shutdown escalates: SIGINT first, then ``terminate()``, then ``kill()``,
    waiting up to ``timeout`` seconds after each of the first two steps. The
    log is closed even if stopping the process raises.

    Args:
        proc: A ``subprocess.Popen`` or ``None``; a missing or already-exited
            process is left alone.
        log: An open file object or ``None``.
        timeout: Seconds to wait after SIGINT and again after ``terminate()``
            before escalating.
    """
    try:
        if proc is None or proc.poll() is not None:
            return

        proc.send_signal(signal.SIGINT)
        try:
            proc.wait(timeout=timeout)
            return
        except subprocess.TimeoutExpired:
            pass

        proc.terminate()
        try:
            proc.wait(timeout=timeout)
        except subprocess.TimeoutExpired:
            # Last resort so a stuck server is never left running.
            proc.kill()
            proc.wait()
    finally:
        if log is not None:
            log.close()
| 111 | + |
| 112 | + |
def bench_cmd(args, storage, node):
    """Build the argument vector for one benchmark invocation.

    Args:
        args: Parsed command-line namespace; reads ``pool_name``, ``nb``,
            ``nq``, ``dim``, ``k``, ``iters``, ``block``, ``capacity_mb`` and
            ``threads``.
        storage: Value for the benchmark's ``--storage`` option.
        node: Node index, passed as ``--node``.

    Returns:
        A list starting with the benchmark binary path followed by
        flag/value pairs.
    """
    flag_values = (
        ("--storage", storage),
        ("--pool", args.pool_name),
        ("--node", str(node)),
        ("--nb", str(args.nb)),
        ("--nq", str(args.nq)),
        ("--dim", str(args.dim)),
        ("--k", str(args.k)),
        ("--iters", str(args.iters)),
        ("--block", str(args.block)),
        ("--pool-mb", str(args.capacity_mb)),
        ("--threads", str(args.threads)),
    )
    return [str(BENCH_BIN), *(item for pair in flag_values for item in pair)]
| 128 | + |
| 129 | + |
def run_node(args, storage, node, out_dir):
    """Run the benchmark once for a given node and capture its output.

    The child inherits the current environment with ``PGAS_LOCAL_NODE`` set to
    the node index and ``PGAS_NUM_NODES`` set to ``"2"``. Stdout and stderr
    both go to ``<storage>_node<node>.log`` inside ``out_dir``.

    Args:
        args: Parsed command-line namespace, forwarded to ``bench_cmd``.
        storage: Storage mode name; also used in the log file name.
        node: Node index.
        out_dir: Directory (``Path``) that receives the log file.

    Returns:
        Path of the log file that was written.

    Raises:
        RuntimeError: If the benchmark exits with a non-zero status.
    """
    log_path = out_dir / f"{storage}_node{node}.log"
    env = {**os.environ, "PGAS_LOCAL_NODE": str(node), "PGAS_NUM_NODES": "2"}

    with open(log_path, "w", encoding="utf-8") as log:
        completed = subprocess.run(
            bench_cmd(args, storage, node),
            stdout=log,
            stderr=subprocess.STDOUT,
            env=env,
            cwd=ROOT,
        )

    if completed.returncode == 0:
        return log_path
    raise RuntimeError(f"{storage} node {node} failed, see {log_path}")
| 140 | + |
| 141 | + |
def parse_results(log_paths, csv_path):
    """Collect ``FAISS_SPLASH_RESULT`` lines from logs and write them as CSV.

    Every line that starts with ``FAISS_SPLASH_RESULT`` is split on
    whitespace; each following ``key=value`` token becomes one entry of that
    line's row. Tokens without ``=`` are skipped instead of aborting the run.

    Args:
        log_paths: Iterable of ``Path`` objects for the benchmark logs.
        csv_path: Destination of the CSV file.

    Returns:
        A list of dicts (string keys and values), one per result line, in the
        order encountered. Rows keep every parsed key, including keys that are
        not CSV columns.
    """
    # Local import: csv is not among the module-level imports.
    import csv

    marker = "FAISS_SPLASH_RESULT"
    fields = ["node", "storage", "nb", "nq", "dim", "k", "iters", "block", "db_mb",
              "total_ms", "pool_write_ms", "pool_read_ms", "qps", "checksum"]

    rows = []
    for log_path in log_paths:
        text = log_path.read_text(encoding="utf-8", errors="replace")
        for line in text.splitlines():
            if not line.startswith(marker):
                continue
            row = {}
            for token in line.split()[1:]:
                key, sep, value = token.partition("=")
                if not sep:
                    # Stray word on the result line; ignore it.
                    continue
                row[key] = value
            rows.append(row)

    # newline="" plus lineterminator="\n" keeps plain LF line endings while
    # letting the csv module quote values that contain commas or quotes.
    with open(csv_path, "w", encoding="utf-8", newline="") as f:
        writer = csv.DictWriter(
            f,
            fieldnames=fields,
            restval="",
            extrasaction="ignore",
            lineterminator="\n",
        )
        writer.writeheader()
        writer.writerows(rows)
    return rows
| 161 | + |
| 162 | + |
def _parse_args():
    """Define the command-line options and parse them from ``sys.argv``."""
    parser = argparse.ArgumentParser(description="Run FAISS CPU against native DRAM and Splash/CXLMemSim SHMEM pool.")
    parser.add_argument("--nb", type=int, default=10000)
    parser.add_argument("--nq", type=int, default=64)
    parser.add_argument("--dim", type=int, default=64)
    parser.add_argument("--k", type=int, default=10)
    parser.add_argument("--iters", type=int, default=1)
    parser.add_argument("--block", type=int, default=2048)
    parser.add_argument("--capacity-mb", type=int, default=256)
    parser.add_argument("--latency-ns", type=int, default=100)
    parser.add_argument("--threads", type=int, default=1)
    parser.add_argument("--jobs", type=int, default=os.cpu_count() or 4)
    parser.add_argument("--pool-name", default="/faiss_cxl_pool")
    parser.add_argument("--pool-provider", choices=["splash", "cxlmemsim"], default="splash")
    parser.add_argument("--skip-build", action="store_true")
    return parser.parse_args()


def _refresh_latest_link(out_dir):
    """Point ``ARTIFACT/latest`` at ``out_dir``, replacing any existing link."""
    latest = ARTIFACT / "latest"
    # is_symlink() also catches a dangling link, for which exists() is False.
    if latest.exists() or latest.is_symlink():
        latest.unlink()
    latest.symlink_to(out_dir, target_is_directory=True)


def main():
    """Build (unless skipped), run the native and pooled benchmarks, and report.

    Steps, in order:
      1. Check that the FAISS and Splash checkouts exist.
      2. Create a timestamped output directory under ``ARTIFACT``.
      3. Build FAISS and the benchmark binary unless ``--skip-build`` is given.
      4. Run the ``native`` benchmark for nodes 0 and 1.
      5. Start the pool server, run the ``cxl-pool`` benchmark for nodes 0 and
         1, and always stop the server afterwards.
      6. Write ``results.csv``, refresh the ``latest`` symlink and print a
         per-row summary.

    Raises:
        FileNotFoundError: If a required checkout or the benchmark binary is
            missing.
        RuntimeError: If a benchmark run or the pool server fails.
    """
    args = _parse_args()

    if not FAISS_SRC.exists():
        raise FileNotFoundError(f"missing FAISS checkout at {FAISS_SRC}")
    if not SPLASH.exists():
        raise FileNotFoundError(f"missing Splash checkout at {SPLASH}")

    ARTIFACT.mkdir(parents=True, exist_ok=True)
    out_dir = ARTIFACT / time.strftime("run_%Y%m%d_%H%M%S")
    out_dir.mkdir()

    if not args.skip_build:
        configure_and_build_faiss(args.jobs)
        build_benchmark()
    if not BENCH_BIN.exists():
        # Only reachable with --skip-build; fail here with a clear message
        # instead of on the first benchmark launch.
        raise FileNotFoundError(f"missing {BENCH_BIN}; run once without --skip-build")

    logs = []
    for node in (0, 1):
        logs.append(run_node(args, "native", node, out_dir))

    proc = None
    pool_log = None
    try:
        proc, pool_log = start_pool(args, out_dir / f"{args.pool_provider}_pool.log")
        for node in (0, 1):
            logs.append(run_node(args, "cxl-pool", node, out_dir))
    finally:
        stop_proc(proc, pool_log)

    rows = parse_results(logs, out_dir / "results.csv")
    _refresh_latest_link(out_dir)

    print(f"wrote {out_dir / 'results.csv'}")
    if not rows:
        print("warning: no FAISS_SPLASH_RESULT lines found in the benchmark logs")
    for row in rows:
        # .get() keeps the summary from crashing, after all the work is done,
        # on a result line that lacks one of these keys.
        print(
            f"node={row.get('node', 'n/a')} storage={row.get('storage', 'n/a')} "
            f"total_ms={row.get('total_ms', 'n/a')} pool_read_ms={row.get('pool_read_ms', 'n/a')} "
            f"qps={row.get('qps', 'n/a')}"
        )


if __name__ == "__main__":
    main()
0 commit comments