Skip to content

Commit ea692fc

Browse files
[Feat] RAG example (#21)
* [feat] add rag agent example * Update examples/rag_agent/wiki_retriever_mcp/wiki_retriever_mcp.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * Update examples/rag_agent/rag_agent.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * Update examples/rag_agent/README.md Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * Update examples/rag_agent/utils.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * [fix] rag agent * [fmt] black formatter * [fix] lint with black * [fix] pre-commit linter * [fix] change rag example folder name from rag_agent to rag --------- Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
1 parent 989027b commit ea692fc

6 files changed

Lines changed: 657 additions & 0 deletions

File tree

examples/rag/README.md

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
# RAG Agent Example
2+
3+
This example originally runs on a single node with four GPUs, each requiring at least 40GB of memory.
4+
5+
1. Prepare the RAG dataset in the wiki_retriever_mcp folder. Wiki chunks (`nq_list.pkl`) and Faiss index (`nq_hnsw_faiss_n32e40.index`) are required. (Full wiki dump files are huge, additional information will be provided later)
6+
2. Prepare the training data in the `data` folder. Download from [here](https://drive.google.com/drive/folders/1hEqOY4EbplUB5ew-8UPFhV_5QU2j7WCN?usp=drive_link). `musique_train.parquet` and `musique_dev_128.parquet` are required.
7+
3. Set up the environment for wiki retriever MCP: `bash wiki_retriever_install.sh`. This will install the required packages and set up the environment for the wiki retriever MCP.
8+
4. Start the wiki retriever MCP: `python wiki_retriever_mcp.py`. This will start the wiki retriever MCP server.
9+
5. Start Ray: `bash ../../scripts/restart_ray.sh`. To use Wandb, you need to set the WANDB_API_KEY environment variable before starting Ray.
10+
6. Run the agent: `python rag_agent.py`. This automatically launches 12 agent workers by default.
11+
7. In another terminal, launch the training server: `bash train.sh`.
12+
13+
## Evaluation
14+
15+
Results are coming soon.

examples/rag/rag_agent.py

Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,96 @@
1+
from __future__ import annotations
2+
3+
import os
4+
import re
5+
import shutil
6+
import sys
7+
import tempfile
8+
import time
9+
from typing import Any, Literal, Optional
10+
11+
import dotenv
12+
import termcolor
13+
from agents import (
14+
Agent,
15+
Runner,
16+
function_tool,
17+
gen_trace_id,
18+
set_trace_processors,
19+
set_tracing_disabled,
20+
trace,
21+
)
22+
from agents.extensions.models.litellm_model import LitellmModel
23+
from agents.mcp import MCPServer, MCPServerSse
24+
from agents.model_settings import ModelSettings
25+
from agents.tracing.processors import BatchTraceProcessor, ConsoleSpanExporter
26+
from utils import compute_scores
27+
28+
import agentlightning
29+
from agentlightning import (
30+
LLM,
31+
LitAgent,
32+
NamedResources,
33+
Trainer,
34+
configure_logger,
35+
reward,
36+
)
37+
38+
configure_logger()
39+
40+
agent_prompt = """You are an assistant who answers questions using Wikipedia retriever. Answer the question using only the retrieved passages. Verify your answer directly against the text.
41+
42+
After each search:
43+
- Summarize findings.
44+
- Decide if info is sufficient.
45+
- If sufficient: reply in <answer>...</answer> with your answer. The answer must be extremely concise: a single word or a few words only.
46+
- If not: suggest the next search needed to fill info gaps. The system will return top 3 relevant Wikipedia chunks.
47+
- Explain your reasoning for the chosen action.
48+
49+
Repeat as needed. When done, wrap your final, concise answer in <answer> tags."""
50+
51+
52+
class RAGAgent(LitAgent):
53+
def __init__(self):
54+
self.mcp_server_url = "http://127.0.0.1:8099/sse"
55+
56+
async def training_rollout_async(self, task: Any, rollout_id: str, resources: NamedResources) -> Any:
57+
llm: LLM = resources.get("main_llm")
58+
print("Training with model:", llm.model, "on endpoint:", llm.endpoint)
59+
async with MCPServerSse(
60+
name="wiki_retriever_mcp",
61+
params={"url": self.mcp_server_url},
62+
) as server:
63+
agent = Agent(
64+
model=LitellmModel(model="hosted_vllm/" + llm.model, base_url=llm.endpoint),
65+
model_settings=ModelSettings(
66+
max_tokens=4096,
67+
temperature=0.7,
68+
),
69+
name="Assistant",
70+
instructions=agent_prompt,
71+
mcp_servers=[server],
72+
)
73+
result = await Runner.run(agent, task["question"])
74+
answer = result.final_output
75+
reward = compute_scores(answer, str(task["answer"]))
76+
print(
77+
"question:{} answer: {} ground_truth: {} reward: {}".format(
78+
task["question"], answer, task["answer"], reward
79+
)
80+
)
81+
return reward
82+
83+
async def validation_rollout_async(self, task: Any, rollout_id: str, resources: NamedResources) -> Any:
84+
llm: LLM = resources.get("main_llm")
85+
resources = {
86+
"main_llm": LLM(
87+
endpoint=llm.endpoint,
88+
model=llm.model,
89+
sampling_parameters={"temperature": 0.7},
90+
)
91+
}
92+
return await self.training_rollout_async(task, rollout_id, resources)
93+
94+
95+
if __name__ == "__main__":
96+
Trainer(n_workers=12).fit(RAGAgent(), "http://localhost:9999/")

examples/rag/train.sh

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
#!/bin/bash
2+
3+
set -e
4+
5+
export N_GPUS=1
6+
export BASE_MODEL=Qwen/Qwen3-1.7B
7+
export DATA_DIR=data
8+
export ROLLOUT_TP_SIZE=1
9+
export EXPERIMENT_NAME=rag_agent
10+
export PROJECT_NAME=AgentLightning
11+
12+
echo "Starting training script..."
13+
14+
python -m agentlightning.verl \
15+
algorithm.adv_estimator=grpo \
16+
data.train_files=${DATA_DIR}/musique_train.parquet \
17+
data.val_files=${DATA_DIR}/musique_dev_128.parquet \
18+
actor_rollout_ref.rollout.tensor_model_parallel_size=$ROLLOUT_TP_SIZE \
19+
trainer.n_gpus_per_node=${N_GPUS} \
20+
data.train_batch_size=32 \
21+
actor_rollout_ref.rollout.n=4 \
22+
actor_rollout_ref.actor.ppo_mini_batch_size=32 \
23+
actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=4 \
24+
actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=4 \
25+
actor_rollout_ref.rollout.multi_turn.format=hermes \
26+
actor_rollout_ref.model.path=${BASE_MODEL} \
27+
data.max_prompt_length=4096 \
28+
data.max_response_length=2048 \
29+
data.truncation='error' \
30+
trainer.val_before_train=True \
31+
actor_rollout_ref.actor.optim.lr=1e-6 \
32+
actor_rollout_ref.model.use_remove_padding=True \
33+
actor_rollout_ref.actor.use_kl_loss=False \
34+
actor_rollout_ref.actor.kl_loss_coef=0.000 \
35+
actor_rollout_ref.actor.entropy_coeff=0 \
36+
actor_rollout_ref.actor.clip_ratio_low=0.2 \
37+
actor_rollout_ref.actor.clip_ratio_high=0.3 \
38+
actor_rollout_ref.model.enable_gradient_checkpointing=True \
39+
actor_rollout_ref.actor.fsdp_config.param_offload=True \
40+
actor_rollout_ref.actor.fsdp_config.optimizer_offload=True \
41+
actor_rollout_ref.rollout.name=vllm \
42+
actor_rollout_ref.rollout.gpu_memory_utilization=0.8 \
43+
actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=8 \
44+
actor_rollout_ref.ref.fsdp_config.param_offload=True \
45+
algorithm.use_kl_in_reward=False \
46+
trainer.critic_warmup=0 \
47+
trainer.logger=['console','wandb'] \
48+
trainer.project_name=${PROJECT_NAME} \
49+
trainer.experiment_name=${EXPERIMENT_NAME} \
50+
trainer.nnodes=1 \
51+
trainer.save_freq=40 \
52+
trainer.test_freq=20 \
53+
trainer.total_epochs=2 $@

0 commit comments

Comments
 (0)