Commit 78182cd

Openclaw exp (#17)
* deep-fin-pre-commit-patch
* revise openclaw training
* add illustration
* add better reward for openclaw agent build
1 parent a77aa3c commit 78182cd

7 files changed: +617 -75 lines changed

docs/en/example_train_multi_model.md

Lines changed: 3 additions & 0 deletions
@@ -90,6 +90,9 @@ graph TB
 C -->|end_episode + reward_14b| S2
 ```

+![alt text](https://img.alicdn.com/imgextra/i3/O1CN01vHfNt41LRcQeDMjE4_!!6000000001296-2-tps-1408-768.png)
+
+
 **Architecture Explanation**:

 - **Swarm Server 1 (Port 10086)**: Hosts the 7B model, responsible for Agent 1 and Agent 3's inference and training

docs/en/example_train_multi_model.zh.md

Lines changed: 5 additions & 0 deletions
@@ -88,6 +88,9 @@ graph TB
 C -->|end_episode + reward_14b| S2
 ```

+![alt text](https://img.alicdn.com/imgextra/i3/O1CN01vHfNt41LRcQeDMjE4_!!6000000001296-2-tps-1408-768.png)
+
+
 **Architecture Explanation**

 - **Swarm Server 1 (Port 10086)**: Hosts the 7B model, responsible for the inference and training of Agent 1 and Agent 3
@@ -176,6 +179,8 @@ sequenceDiagram
 4. Report each reward to its corresponding Swarm Server
 5. The two servers perform their policy-gradient updates independently

+
+
 ## Training Curves

 ![alt text](https://img.alicdn.com/imgextra/i2/O1CN0161wtDk1zZwFmIX15x_!!6000000006729-2-tps-2978-1413.png)
Lines changed: 47 additions & 0 deletions
@@ -0,0 +1,47 @@
+# OpenClaw Reward Cheatsheet
+
+## Run the test
+
+```bash
+cd agentjet/tutorial/opencode_build_openclaw_agent
+
+# pointwise (default)
+DASHSCOPE_API_KEY=your_key python test_reward.py
+
+# listwise
+REWARD_MODE=listwise DASHSCOPE_API_KEY=your_key python test_reward.py
+```
+
+## Run the training endpoint
+
+```bash
+# pointwise (default)
+AJET_SWARM_URL=http://localhost:10086 \
+DASHSCOPE_API_KEY=your_key \
+REWARD_MODE=pointwise \
+python fake_vllm_endpoint.py
+
+# listwise
+AJET_SWARM_URL=http://localhost:10086 \
+DASHSCOPE_API_KEY=your_key \
+REWARD_MODE=listwise \
+python fake_vllm_endpoint.py
+```
+
+## Reward modes
+
+| Mode | Description |
+|------|-------------|
+| `pointwise` | Each response scored independently (0.0–1.0) |
+| `listwise` | All responses ranked together (best=1.0, worst=0.0) |
+
+## Environment variables
+
+| Variable | Default | Description |
+|----------|---------|-------------|
+| `REWARD_MODE` | `pointwise` | `pointwise` or `listwise` |
+| `DASHSCOPE_API_KEY` | | DashScope API key (required) |
+| `JUDGE_MODEL` | `qwen-plus` | Judge model name |
+| `JUDGE_BASE_URL` | DashScope endpoint | Judge model base URL |
+| `AJET_SWARM_URL` | `http://localhost:10086` | Swarm server URL |
+| `NUM_REPEAT` | `4` | GRPO N (responses per query) |
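
Reviewer note: the two reward modes in the cheatsheet can be illustrated with a small sketch. This is not the code from this commit. The helper names (`select_reward_mode`, `pointwise_rewards`, `listwise_rewards`) are hypothetical, and the linear spacing of the intermediate listwise rewards is an assumption, since the cheatsheet only fixes best=1.0 and worst=0.0.

```python
from typing import List, Mapping

VALID_MODES = ("pointwise", "listwise")


def select_reward_mode(env: Mapping[str, str]) -> str:
    """Read REWARD_MODE from an environment-like mapping (default: pointwise)."""
    mode = env.get("REWARD_MODE", "pointwise").strip().lower()
    if mode not in VALID_MODES:
        raise ValueError(f"REWARD_MODE must be one of {VALID_MODES}, got {mode!r}")
    return mode


def pointwise_rewards(judge_scores: List[float]) -> List[float]:
    """Score each response independently, clamping into [0.0, 1.0]."""
    return [min(1.0, max(0.0, score)) for score in judge_scores]


def listwise_rewards(ranking: List[int]) -> List[float]:
    """Turn a best-to-worst ranking of response indices into rewards.

    ranking[0] is the index of the best response (reward 1.0) and
    ranking[-1] the worst (reward 0.0). Linear spacing in between is
    an assumption made for this sketch.
    """
    n = len(ranking)
    if sorted(ranking) != list(range(n)):
        raise ValueError("ranking must be a permutation of 0..n-1")
    if n == 1:
        return [1.0]
    rewards = [0.0] * n
    for position, response_index in enumerate(ranking):
        rewards[response_index] = 1.0 - position / (n - 1)
    return rewards
```

In a real script, `select_reward_mode(os.environ)` would pick up the `REWARD_MODE` variable from the table above. With `NUM_REPEAT=4`, a listwise ranking would have four entries.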

tutorial/opencode_build_openclaw_agent/fake_vllm_endpoint.py

Lines changed: 11 additions & 3 deletions
@@ -25,7 +25,7 @@
 import sys
 sys.path.insert(0, os.path.dirname(__file__))

-from on_user_submit_new_requests import on_user_submit_new_requests
+from on_user_submit_new_requests import on_user_submit_new_requests, get_query_history
 from on_compute_relative_reward import on_compute_relative_reward

 # Configuration
@@ -91,6 +91,14 @@ async def proxy_chat_completion(base_url: str, api_key: str, request: Request, i
     json_data = await request.json()
     json_data["stream"] = is_stream

+    # Remove fields not supported by vLLM to avoid warnings
+    UNSUPPORTED_FIELDS = {"strict", "store"}
+    for field in UNSUPPORTED_FIELDS:
+        json_data.pop(field, None)
+    # Also remove 'strict' from response_format if present
+    if "response_format" in json_data and isinstance(json_data["response_format"], dict):
+        json_data["response_format"].pop("strict", None)
+
     async with httpx.AsyncClient(timeout=300.0) as client:
         resp = await client.post(f"{base_url}/chat/completions", json=json_data, headers=headers)
         resp.raise_for_status()
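
Reviewer note: the field-stripping step added in this hunk could also be factored into a small pure function, which makes it easy to test without a running server. The sketch below is one possible shape, not the code in this commit. The function name `strip_unsupported_fields` is hypothetical, and unlike the hunk above it returns a cleaned copy instead of mutating the request body in place.

```python
from typing import Any, Dict

# Same field names as in the hunk above.
UNSUPPORTED_FIELDS = {"strict", "store"}


def strip_unsupported_fields(payload: Dict[str, Any]) -> Dict[str, Any]:
    """Return a copy of the request body without the unsupported fields.

    Top-level keys in UNSUPPORTED_FIELDS are dropped, and 'strict' is
    also dropped from a dict-valued 'response_format'. The input dict
    is left untouched.
    """
    cleaned = {k: v for k, v in payload.items() if k not in UNSUPPORTED_FIELDS}
    response_format = cleaned.get("response_format")
    if isinstance(response_format, dict):
        cleaned["response_format"] = {
            k: v for k, v in response_format.items() if k != "strict"
        }
    return cleaned
```

Returning a copy keeps the original client payload available for logging or for the `/requests` record. Mutating in place, as the commit does, avoids the extra allocation.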
@@ -200,7 +208,7 @@ async def handle_one2many_request(request: Request, request_id: str) -> Dict | L

     valid_results = await run_all_episodes(request, is_stream)
     all_answers = [extract_assistant_message(r.response) for r in valid_results]
-    rewards = await on_compute_relative_reward(valid_results, all_answers)
+    rewards = await on_compute_relative_reward(valid_results, all_answers, question=user_query)

     await finalize_episodes(task, valid_results, rewards)
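
Reviewer note: this hunk passes the user query into the reward computation, but the diff shown here does not include how `on_compute_relative_reward` uses it. One plausible use is to give the judge model the question alongside the candidate answers. The sketch below only illustrates that idea. The function name `build_judge_prompt` and the prompt wording are hypothetical and are not taken from the repository.

```python
from typing import List, Optional


def build_judge_prompt(answers: List[str], question: Optional[str] = None) -> str:
    """Assemble a plain-text judge prompt (hypothetical format).

    When a question is supplied it is placed first, so the judge can
    rate the answers against it instead of rating them in isolation.
    """
    parts = []
    if question:
        parts.append(f"Question:\n{question}")
    for number, answer in enumerate(answers, start=1):
        parts.append(f"Response {number}:\n{answer}")
    parts.append("Rate each response for how well it answers the question.")
    return "\n\n".join(parts)
```

Keeping `question` optional means older call sites that pass only the answers would continue to work.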

@@ -259,7 +267,7 @@ async def health_check():
 @app.get("/requests")
 async def get_requests():
     """Get all recorded user requests."""
-    return {"requests": USER_REQUEST_RECORD}
+    return {"requests": get_query_history()}


 if __name__ == "__main__":
