-
Notifications
You must be signed in to change notification settings - Fork 2
Expand file tree
/
Copy pathexample.py
More file actions
103 lines (85 loc) · 3.09 KB
/
Copy pathexample.py
File metadata and controls
103 lines (85 loc) · 3.09 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
import os
import sys
import time
from pathlib import Path
import requests
from .config import AppConfig
# Project configuration object; supplies the default provider/model names
# that are copied into PAYLOAD below.
_cfg = AppConfig()

# Root URL of the service. Override with the CYTEONTO_URL environment
# variable; trailing slashes are stripped so endpoint paths can be appended
# with a single "/".
BASE_URL = os.getenv("CYTEONTO_URL", "https://cyteonto.nygen.io").rstrip("/")

# Optional API keys taken from the environment (None when unset). submit()
# only forwards the ones that are truthy.
LLM_API_KEY = os.getenv("LLM_API_KEY")
EMBEDDING_API_KEY = os.getenv("EMBEDDING_API_KEY")

# Example request body that main() passes to submit().
PAYLOAD: dict = {
    "authorLabels": [
        "alveolar macrophage",
        "regulatory T cell",
        "CD8-positive, alpha-beta T cell",
    ],
    # One list of labels per algorithm name.
    "algorithms": {
        "algo1": ["lung macrophage", "Treg", "CD8 T cell"],
        "algo2": ["alveolar mac", "T regulatory cell", "cytotoxic T cell"],
    },
    # Provider/model defaults come from the project configuration.
    "llmProvider": _cfg.DEFAULT_LLM_PROVIDER,
    "llmModel": _cfg.DEFAULT_LLM_MODEL,
    "embeddingProvider": _cfg.DEFAULT_EMBEDDING_PROVIDER,
    "embeddingModel": _cfg.DEFAULT_EMBEDDING_MODEL,
    "maxDescriptionConcurrency": 50,
    "embeddingMaxConcurrent": 50,
    "usePubmedTool": False,
    "reasoning": False,
}

# Polling cadence and overall deadline (seconds) handed to poll() by main().
POLL_INTERVAL_S = 10
POLL_TIMEOUT_S = 60 * 60  # one hour

# Format requested from the result endpoint, and where the file is written.
RESULT_FORMAT = "csv"
OUTPUT_DIR = Path("./cyteonto_results")
def submit(payload: dict) -> str:
    """Post a comparison request and return the run id from the response.

    A shallow copy of ``payload`` is sent as JSON to ``{BASE_URL}/compare``,
    so the caller's dict is never mutated. ``llmApiKey`` and
    ``embeddingApiKey`` are added to the copy only when the corresponding
    module-level key is truthy.

    Args:
        payload: Request body to submit.

    Returns:
        The ``runId`` value from the JSON response.

    Raises:
        requests.HTTPError: If the response carries an error status.
        KeyError: If the response JSON lacks ``runId`` or ``state``.
    """
    optional_keys = (
        ("llmApiKey", LLM_API_KEY),
        ("embeddingApiKey", EMBEDDING_API_KEY),
    )
    request_body = dict(payload)
    # Unset/empty keys are left out of the body entirely.
    request_body.update(
        {field: secret for field, secret in optional_keys if secret}
    )

    response = requests.post(
        f"{BASE_URL}/compare",
        json=request_body,
        timeout=30,
    )
    response.raise_for_status()

    reply = response.json()
    print(f"[submit] runId={reply['runId']} state={reply['state']}")
    return reply["runId"]
def poll(run_id: str, interval_s: int, timeout_s: int) -> dict:
    """Poll the status endpoint until the run completes, fails, or times out.

    Each iteration GETs ``{BASE_URL}/status/{run_id}``. State transitions are
    printed once per change; when the run reaches ``completed`` and the
    status carries a truthy ``modelPairUsage``, that value is printed too.

    Args:
        run_id: Identifier of the run to watch.
        interval_s: Seconds to wait between polls.
        timeout_s: Overall budget in seconds. A non-positive value raises
            ``TimeoutError`` immediately without issuing any request.

    Returns:
        The status JSON of the first poll whose ``state`` is ``completed``
        or ``failed``.

    Raises:
        TimeoutError: If no terminal state is seen within ``timeout_s``.
        requests.HTTPError: If a status response carries an error status.
        KeyError: If a status response lacks ``state``.
    """
    # Use the monotonic clock so wall-clock adjustments (NTP steps, manual
    # changes) can neither shorten nor extend the timeout.
    deadline = time.monotonic() + timeout_s
    last_state = None
    while time.monotonic() < deadline:
        resp = requests.get(f"{BASE_URL}/status/{run_id}", timeout=30)
        resp.raise_for_status()
        status = resp.json()
        state = status["state"]

        if state != last_state:
            print(f"[status] {state}")
            if state == "completed" and status.get("modelPairUsage"):
                print(f"[status] modelPairUsage={status['modelPairUsage']}")
            last_state = state

        if state in ("completed", "failed"):
            return status

        # Never sleep past the deadline: wait for the shorter of the poll
        # interval and the time that is actually left.
        remaining = deadline - time.monotonic()
        if remaining <= 0:
            break
        time.sleep(min(interval_s, remaining))

    raise TimeoutError(f"Run {run_id} did not finish within {timeout_s}s")
def fetch_result(run_id: str, fmt: str, out_dir: Path) -> Path:
    """Download a run's result and write it to ``out_dir``.

    The output directory is created (including parents) before the request
    is made. The response body is written verbatim to
    ``<out_dir>/<run_id>.<ext>``, where ``ext`` is ``csv`` when ``fmt`` is
    ``"csv"`` and ``json`` for any other value.

    Args:
        run_id: Identifier of the run whose result is fetched.
        fmt: Value sent as the ``format`` query parameter.
        out_dir: Directory that receives the downloaded file.

    Returns:
        Path of the file that was written.

    Raises:
        requests.HTTPError: If the response carries an error status.
    """
    out_dir.mkdir(parents=True, exist_ok=True)

    url = f"{BASE_URL}/result/{run_id}"
    response = requests.get(url, params={"format": fmt}, timeout=60)
    response.raise_for_status()

    # Anything other than "csv" is stored with a .json extension.
    if fmt == "csv":
        extension = "csv"
    else:
        extension = "json"

    destination = out_dir / f"{run_id}.{extension}"
    destination.write_bytes(response.content)
    return destination
def main() -> int:
    """Run the example end to end and return a process exit code.

    Steps: check ``{BASE_URL}/health``, submit ``PAYLOAD``, poll until the
    run reaches a terminal state, then download the result on success.

    Returns:
        ``0`` when the result was saved, ``1`` when the run ended in the
        ``failed`` state (the reported error is printed to stderr).
    """
    # Fail fast if the service is unreachable or unhealthy.
    requests.get(f"{BASE_URL}/health", timeout=15).raise_for_status()

    run_id = submit(PAYLOAD)
    final_status = poll(run_id, POLL_INTERVAL_S, POLL_TIMEOUT_S)

    if final_status["state"] != "failed":
        saved_to = fetch_result(run_id, RESULT_FORMAT, OUTPUT_DIR)
        print(f"[done] {final_status['numRows']} rows saved to {saved_to}")
        return 0

    print(f"[failed] {final_status.get('error')}", file=sys.stderr)
    return 1
# Script entry point: main()'s return value becomes the process exit status.
if __name__ == "__main__":
    sys.exit(main())