forked from ashishMenon05/NEXUS-AI
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathinference.py
More file actions
157 lines (131 loc) · 7.11 KB
/
inference.py
File metadata and controls
157 lines (131 loc) · 7.11 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
#!/usr/bin/env python3
"""
NEXUS Inference Script — OpenEnv Competition Submission
Must be placed at the root of the project repo.
Required environment variables (set in .env or passed externally):
API_BASE_URL - OpenAI-compatible endpoint. Examples:
HF Space: https://api-inference.huggingface.co/v1
Local dev: http://localhost:11434/v1 (Ollama)
MODEL_NAME - Model identifier, e.g., "Qwen/Qwen2.5-3B-Instruct" (HF)
or "qwen2.5:3b" (Ollama)
HF_TOKEN - Your HuggingFace API key when using HF Inference endpoints.
For local Ollama usage set this to "ollama" or any string.
The OpenAI client transparently works with HF, Ollama, or any compatible API.
Output format (stdout only):
[START] task=<task_name> env=nexus-incident-investigation model=<model>
[STEP] step=<n> action="<text>" reward=<0.00> done=<true|false> error=null
[END] success=<true|false> steps=<n> score=<0.000> rewards=<r1,r2,...>
"""
import os
import sys
import asyncio
import re
from pathlib import Path
# ── Resolve project root and load .env ────────────────────────────────────────
ROOT = Path(__file__).resolve().parent
sys.path.insert(0, str(ROOT / "backend"))
from dotenv import load_dotenv
load_dotenv(ROOT / "backend" / ".env")
# ── Required variables ─────────────────────────────────────────────────────────
API_BASE_URL = os.environ.get("API_BASE_URL", "https://api.openai.com/v1")
MODEL_NAME = os.environ.get("MODEL_NAME", "gpt-4o")
HF_TOKEN = os.environ.get("HF_TOKEN", "")
API_KEY = os.environ.get("OPENAI_API_KEY", HF_TOKEN) if os.environ.get("OPENAI_API_KEY", HF_TOKEN) else os.environ.get("API_KEY", "ollama")
# ── OpenAI client (required by competition spec) ───────────────────────────────
from openai import OpenAI
client = OpenAI(base_url=API_BASE_URL, api_key=API_KEY)
# ── Backend imports ────────────────────────────────────────────────────────────
from backend.core.environment import NexusEnvironment
from backend.api.schemas.action import NexusAction, ToolCall
# ── Tool-call parser ───────────────────────────────────────────────────────────
def parse_tool_calls(text: str) -> list:
tool_calls = []
for match in re.finditer(r"TOOL:\s*([a-zA-Z0-9_]+)\(([^)]*)\)", text):
name = match.group(1)
args_s = match.group(2)
params = {}
for kv in re.finditer(r"(\w+)=['\"]?([^,'\"]+)['\"]?", args_s):
params[kv.group(1)] = kv.group(2)
tool_calls.append(ToolCall(tool_name=name, params=params))
return tool_calls
# ── Tasks definition (3 required by spec) ─────────────────────────────────────
TASKS = [
{"name": "software-incident", "difficulty": "easy"},
{"name": "business-process-failure", "difficulty": "medium"},
{"name": "cascade-system-failure", "difficulty": "hard"},
]
SYSTEM_PROMPT = (
"You are an expert incident investigator. Investigate the provided system incident.\n"
"Use tools to gather evidence. Format tool calls EXACTLY as:\n"
" TOOL: tool_name(param='value')\n"
"Available tools: read_logs, check_config, query_database, check_service_status, "
"run_diagnostic, propose_fix, verify_fix\n"
"Be concise. Produce a root-cause hypothesis and propose a fix."
)
MAX_STEPS = int(os.environ.get("MAX_STEPS", "8"))
# ── Main inference loop ────────────────────────────────────────────────────────
async def run():
env = NexusEnvironment()
for task in TASKS:
_print(f"[START] task={task['name']} env=nexus-incident-investigation model={MODEL_NAME}")
obs = await env.reset(task["name"], seed=42)
messages = [
{"role": "system", "content": SYSTEM_PROMPT},
]
done = False
step_n = 0
rewards = []
while not done and step_n < MAX_STEPS:
step_n += 1
# Build user message from observation
user_content = (
f"Scenario: {obs.scenario_description}\n"
f"Context: {obs.scenario_context}\n"
f"Clues found so far: {obs.clues_found}\n"
f"Round {obs.round}. Investigate and call tools."
)
if obs.tool_results:
user_content += "\nLast tool results:\n" + "\n".join(
f" {tr.tool_name}: {tr.result[:200]}" for tr in obs.tool_results
)
messages.append({"role": "user", "content": user_content})
# ── LLM call (OpenAI client, required by spec) ─────────────────
try:
resp = client.chat.completions.create(
model=MODEL_NAME,
messages=messages,
max_tokens=300,
temperature=0.7
)
action_text = resp.choices[0].message.content or ""
except Exception as e:
action_text = f"Error calling LLM: {e}"
messages.append({"role": "assistant", "content": action_text})
# ── Parse tool calls & step environment ────────────────────────
tool_calls = parse_tool_calls(action_text)
action = NexusAction(
agent_id="agent_a",
message=action_text,
tool_calls=tool_calls,
confidence=0.8
)
obs, reward, done, info = await env.step(action)
rewards.append(reward)
# ── [STEP] log (exact format required) ────────────────────────
clean = action_text.replace("\n", " ")[:200]
_print(
f'[STEP] step={step_n} action="{clean}" '
f'reward={reward:.2f} done={str(done).lower()} error=null'
)
# ── [END] log ──────────────────────────────────────────────────────
final_score = info.get("final_score", rewards[-1] if rewards else 0.0)
success = info.get("success", final_score >= 0.5)
rewards_str = ",".join(f"{r:.2f}" for r in rewards)
_print(
f"[END] success={str(success).lower()} steps={step_n} "
f"score={final_score:.3f} rewards={rewards_str}"
)
def _print(line: str):
print(line, flush=True)
if __name__ == "__main__":
asyncio.run(run())