Skip to content

Commit 7e05fb6

Browse files
committed
try vibe training
1 parent 17ca418 commit 7e05fb6

File tree

10 files changed

+935
-3
lines changed

10 files changed

+935
-3
lines changed

docs/en/swarm_vibe_coding.md

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
# Vibe Coding with AgentJet Swarm
2+
3+
AgentJet Swarm client is so simple that even LLMs can tune model using its APIs.
4+
5+
Here is an example:
6+
7+
```txt
8+
Your task:
9+
- Write an intelligent agent that learns the CountDown task (You are an agent specialized in solving countdown number puzzles. Given a target number and a list of source numbers, find a way to reach the target number using basic arithmetic operations (+, -, *, /). Each source number can only be used once.)
10+
- I hope to use the base model '/mnt/data_cpfs/model_cache/modelscope/hub/Qwen/Qwen/Qwen2.5-7B-Instruct'
11+
- Train using 8 GPUs
12+
- Batch Size 16
13+
- I currently do not have a dataset, you need to help me mock a small amount of data for testing
14+
15+
Your skills (First read the SKILL file to acquire necessary knowledge):
16+
ajet/copilot/write-swarm-client/SKILL.md
17+
```
18+
19+
Copy and paste the prompt above into opencode or claude-code, and then hit `ajet-swarm start` and `python /path/to/ai/generated/agent_roll.py`,
20+
and wait for the training to finish.
21+
22+
Reference result:
23+
24+
<div align="center">
25+
<img width="600" alt="image" src="https://img.alicdn.com/imgextra/i2/O1CN01u5JHH521QRGeQAFsL_!!6000000006979-2-tps-1200-600.png"/>
26+
</div>
27+
28+
29+
30+

docs/en/swarm_with_ai_coding.md

Whitespace-only changes.

tutorial/example_math_swarm/math.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
GRPO_N = 4 # grpo group size
1818
NUM_EPOCH = 10000
1919
AJET_SWARM_URL = os.getenv("AJET_SWARM_URL", "http://localhost:10086")
20-
REMOTE_MODEL_PATH = os.getenv("REMOTE_MODEL_PATH", "/mnt/data_cpfs/model_cache/modelscope/hub/Qwen/Qwen/Qwen2.5-3B-Instruct")
20+
REMOTE_MODEL_PATH = os.getenv("REMOTE_MODEL_PATH", "/mnt/data_cpfs/model_cache/modelscope/hub/Qwen/Qwen/Qwen2.5-7B-Instruct")
2121
REMOTE_BATCH_SIZE = 32
2222
REMOTE_ALLOCATE_GPU_PER_NODE = 8
2323

@@ -28,7 +28,7 @@ def main():
2828
reader_type = "huggingface_dat_repo",
2929
reader_config = AjetTaskReader(
3030
huggingface_dat_repo = HuggingfaceDatRepo(
31-
dataset_path = "/mnt/data_cpfs/qingxu.fu/dataset/openai/gsm8k/main",
31+
dataset_path = "/root/agentjet/benchmark_datasets/dataset/gsm8k/socratic",
3232
# dataset_path = "openai/gsm8k",
3333
# dataset_name = "main",
3434
)
@@ -46,7 +46,7 @@ def main():
4646
batch_size=REMOTE_BATCH_SIZE,
4747
num_repeat=GRPO_N,
4848
),
49-
# force_restart=True,
49+
force_restart=True,
5050
)
5151

5252
def rollout(task):
-253 KB
Binary file not shown.
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
# Generate an agent / agent loop with AgentJet Swarm and train it with one key
2+
3+
Use prompt below in opencode or claudecode to generate a one-key-to-tune agent (result is in `tutorial/opencode_build_countdown_agent`, generated by `claude sonnet 4.5`)
4+
5+
=============================
6+
7+
Your task:
8+
- Write an intelligent agent that learns the CountDown task (You are an agent specialized in solving countdown number puzzles. Given a target number and a list of source numbers, find a way to reach the target number using basic arithmetic operations (+, -, *, /). Each source number can only be used once.)
9+
- I hope to use the base model '/mnt/data_cpfs/model_cache/modelscope/hub/Qwen/Qwen/Qwen2.5-7B-Instruct'
10+
- Train using 8 GPUs
11+
- Batch Size 16
12+
- I currently do not have a dataset, you need to help me mock a small amount of data for testing
13+
14+
Your skills (First read the SKILL file to acquire necessary knowledge):
15+
ajet/copilot/write-swarm-client/SKILL.md
16+
17+
=============================
18+
19+
你的任务:
20+
- 编写一个学习CountDown任务的智能体 (You are an agent specialized in solving countdown number puzzles. Given a target number and a list of source numbers, find a way to reach the target number using basic arithmetic operations (+, -, *, /). And each source number can only be used once.)
21+
- 我希望使用基础模型 '/mnt/data_cpfs/model_cache/modelscope/hub/Qwen/Qwen/Qwen2.5-7B-Instruct'
22+
- 使用 8 GPU 训练
23+
- Batch Size 16
24+
- 我目前没有数据集,你需要帮助我mock少量数据以供测试
25+
26+
你的skill(首先读取该SKILL文件,获取必要知识):
27+
ajet/copilot/write-swarm-client/SKILL.md
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
# ------- AI GENERATED --------
2+
# ------- [Read tutorial/opencode_build_countdown_agent.prompt.md] --------
3+
4+
"""
5+
6+
CountDown Number Puzzle Solver Agent
7+
8+
This package contains a trainable agent for solving CountDown number puzzles.
9+
"""
10+
11+
from .agent_run import run_agent_and_compute_reward
12+
13+
__all__ = ["run_agent_and_compute_reward"]
Lines changed: 176 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,176 @@
1+
# ------- AI GENERATED --------
2+
# ------- [Read tutorial/opencode_build_countdown_agent.prompt.md] --------
3+
4+
"""
5+
CountDown Agent Training Script (Swarm Client)
6+
7+
This script connects to the AgentJet Swarm server and trains the countdown agent.
8+
9+
Usage:
10+
python -m tutorial.countdown_agent.agent_roll
11+
12+
Before running:
13+
1. Start the swarm server: ajet-swarm start
14+
2. Ensure the dataset is generated: python tutorial/countdown_agent/generate_countdown_dataset.py
15+
3. Update the configuration variables below to match your setup
16+
"""
17+
18+
from ajet.copilot.job import AgentJetJob
19+
from ajet.tuner_lib.experimental.as_swarm_client import (
20+
SwarmClient,
21+
run_episodes_until_all_complete,
22+
)
23+
from ajet.default_config.ajet_default import (
24+
AjetTaskReader,
25+
JsonlDatasetFile,
26+
JsonlTrainingFp,
27+
)
28+
from ajet.task_reader import RouterTaskReader
29+
from .agent_run import run_agent_and_compute_reward
30+
31+
32+
# --------- Configurations that take effect locally -------------
33+
LOCAL_GRPO_N = 4 # GRPO group size (number of rollouts per task)
34+
LOCAL_NUM_EPOCH = 100 # Number of training epochs
35+
LOCAL_DATASET_PATH = "./tutorial/countdown_agent/countdown_dataset/train.jsonl"
36+
REMOTE_SWARM_URL = "http://localhost:10086" # Swarm server URL
37+
38+
# --------- Configurations that take effect remotely (on swarm server) -------------
39+
REMOTE_BATCH_SIZE = 16 # Batch size for training (as specified by user)
40+
REMOTE_ALLOCATE_GPU_PER_NODE = 8 # Number of GPUs to use (as specified by user)
41+
REMOTE_TRAIN_MODEL = (
42+
"/mnt/data_cpfs/model_cache/modelscope/hub/Qwen/Qwen/Qwen2.5-7B-Instruct"
43+
)
44+
45+
46+
def main():
47+
"""
48+
Main training loop for CountDown agent.
49+
"""
50+
51+
# Load the CountDown dataset
52+
print(f"Loading dataset from: {LOCAL_DATASET_PATH}")
53+
dataset = RouterTaskReader(
54+
reader_type="jsonl_dataset_file",
55+
reader_config=AjetTaskReader(
56+
jsonl_dataset_file=JsonlDatasetFile(
57+
training=JsonlTrainingFp(file_path=LOCAL_DATASET_PATH)
58+
)
59+
),
60+
)
61+
62+
# Connect to swarm server and configure training
63+
print(f"Connecting to swarm server at: {REMOTE_SWARM_URL}")
64+
swarm_worker = SwarmClient(REMOTE_SWARM_URL)
65+
66+
# Configure and start the training engine
67+
print("Configuring training parameters...")
68+
yaml_job = AgentJetJob(
69+
algorithm="grpo", # Using GRPO (Group Relative Policy Optimization)
70+
project_name="countdown-agent",
71+
experiment_name="countdown_solver_7b",
72+
n_gpu=REMOTE_ALLOCATE_GPU_PER_NODE,
73+
model=REMOTE_TRAIN_MODEL,
74+
batch_size=REMOTE_BATCH_SIZE,
75+
num_repeat=LOCAL_GRPO_N,
76+
)
77+
78+
print("Starting swarm engine...")
79+
swarm_worker.auto_sync_train_config_and_start_engine(yaml_job)
80+
81+
print("\n" + "=" * 80)
82+
print("Training started!")
83+
print(f"Model: {REMOTE_TRAIN_MODEL}")
84+
print(f"GPUs: {REMOTE_ALLOCATE_GPU_PER_NODE}")
85+
print(f"Batch size: {REMOTE_BATCH_SIZE}")
86+
print(f"GRPO group size: {LOCAL_GRPO_N}")
87+
print(f"Epochs: {LOCAL_NUM_EPOCH}")
88+
print("=" * 80 + "\n")
89+
90+
def rollout(task):
91+
"""
92+
Execute a single episode (rollout) of the agent.
93+
94+
Args:
95+
task: The countdown problem to solve
96+
97+
Returns:
98+
The reward obtained (or None on failure)
99+
"""
100+
try:
101+
# Begin episode and get API credentials
102+
episode_uuid, api_baseurl_key = swarm_worker.begin_episode()
103+
104+
# Execute agent and compute reward
105+
workflow_output = run_agent_and_compute_reward(
106+
task, api_baseurl_key.base_url, api_baseurl_key.api_key
107+
)
108+
109+
# Report results back to swarm server
110+
swarm_worker.end_episode(task, episode_uuid, workflow_output)
111+
112+
# Print rollout statistics
113+
swarm_worker.print_rollout_stat()
114+
115+
return workflow_output.reward
116+
117+
except Exception as e:
118+
print(f"Error during rollout: {e}")
119+
return None
120+
121+
# Training loop
122+
next_batch = []
123+
total_episodes = 0
124+
125+
for epoch in range(LOCAL_NUM_EPOCH):
126+
print(f"\n{'=' * 80}")
127+
print(f"Epoch {epoch + 1}/{LOCAL_NUM_EPOCH}")
128+
print(f"{'=' * 80}\n")
129+
130+
for task_idx, task in enumerate(dataset.generate_training_tasks()):
131+
# For each task, perform LOCAL_GRPO_N rollouts (GRPO group)
132+
for _ in range(LOCAL_GRPO_N):
133+
next_batch.append(task)
134+
135+
# When batch is full, execute all episodes
136+
if len(next_batch) >= (REMOTE_BATCH_SIZE * LOCAL_GRPO_N):
137+
print(f"\nExecuting batch of {len(next_batch)} episodes...")
138+
139+
# Execute episodes with automatic retry on failure
140+
episode_results = run_episodes_until_all_complete(
141+
next_batch, func=rollout, auto_retry=True
142+
)
143+
144+
total_episodes += len(next_batch)
145+
146+
# Print batch results
147+
successful = sum(
148+
1 for r in episode_results if r is not None and r > 0
149+
)
150+
avg_reward = (
151+
sum(r for r in episode_results if r is not None)
152+
/ len(episode_results)
153+
if episode_results
154+
else 0
155+
)
156+
157+
print(f"\nBatch completed:")
158+
print(f" Total episodes: {len(next_batch)}")
159+
print(f" Successful: {successful}")
160+
print(f" Average reward: {avg_reward:.3f}")
161+
print(f" Total episodes so far: {total_episodes}")
162+
163+
next_batch.clear()
164+
165+
print(f"\nEpoch {epoch + 1} completed!")
166+
167+
print("\n" + "=" * 80)
168+
print("Training completed!")
169+
print(f"Total episodes executed: {total_episodes}")
170+
print("=" * 80)
171+
172+
return None
173+
174+
175+
if __name__ == "__main__":
176+
main()

0 commit comments

Comments
 (0)