Commit 5705288

Update README and complete the inference implementation
1 parent d204eec commit 5705288

11 files changed

Lines changed: 428 additions & 179 deletions

.gitignore

Lines changed: 2 additions & 0 deletions
@@ -176,3 +176,5 @@ checkpoints/
176176
wandb/
177177

178178
dataset/
179+
generation_results/
180+
backups/

README.md

Lines changed: 22 additions & 158 deletions
@@ -1,180 +1,44 @@
1-
# EasyR1: An Efficient, Scalable, Multi-Modality RL Training Framework
1+
# RAVQA
22

33
[![GitHub Repo stars](https://img.shields.io/github/stars/hiyouga/EasyR1)](https://github.com/hiyouga/EasyR1/stargazers)
44
[![Twitter](https://img.shields.io/twitter/follow/llamafactory_ai)](https://twitter.com/llamafactory_ai)
55

6-
This project is a clean fork of the original [veRL](https://github.com/volcengine/verl) project to support vision language models, we thank all the authors for providing such a high-performance RL training framework.
6+
## Document VQA
77

8-
EasyR1 is efficient and scalable due to the design of **[HybirdEngine](https://arxiv.org/abs/2409.19256)** and the latest release of **[vLLM](https://github.com/vllm-project/vllm)**'s SPMD mode.
8+
### Dataset Preprocessing
99

10-
## Features
10+
#### Corpus Building
1111

12-
- Supported models
13-
- Llama3/Qwen2/Qwen2.5 language models
14-
- Qwen2/Qwen2.5-VL vision language models
15-
- DeepSeek-R1 distill models
12+
Set the raw data path and the target path in `rag_serving/build_corpus.py`, then run the script:
1613

17-
- Supported algorithms
18-
- GRPO
19-
- Reinforce++
20-
- ReMax
21-
- RLOO
22-
23-
- Supported datasets
24-
- Any text, vision-text dataset in a [specific format](#custom-dataset)
25-
26-
- Supported tricks
27-
- Padding-free training
28-
- Resuming from checkpoint
29-
- Wandb & SwanLab & Mlflow & Tensorboard tracking
30-
31-
## Requirements
32-
33-
### Software Requirements
34-
35-
- Python 3.9+
36-
- transformers>=4.51.0
37-
- flash-attn>=2.4.3
38-
- vllm>=0.8.3
39-
40-
We provide a [Dockerfile](./Dockerfile) to easily build environments.
41-
42-
We recommend using the [pre-built docker image](https://hub.docker.com/r/hiyouga/verl) in EasyR1.
43-
44-
```bash
45-
docker pull hiyouga/verl:ngc-th2.6.0-cu126-vllm0.8.3-flashinfer0.2.2-cxx11abi0
46-
```
47-
48-
### Hardware Requirements
49-
50-
\* *estimated*
51-
52-
| Method | Bits | 1.5B | 3B | 7B | 32B |
53-
| ------------------------ | ---- | ------ | ------ | ------ | ------- |
54-
| GRPO Full Fine-Tuning | AMP | 2*24GB | 4*40GB | 8*40GB | 16*80GB |
55-
| GRPO Full Fine-Tuning | BF16 | 1*24GB | 1*40GB | 4*40GB | 8*80GB |
56-
57-
> [!NOTE]
58-
> Use `worker.actor.fsdp.torch_dtype=bf16` and `worker.actor.optim.strategy=adamw_bf16` to enable bf16 training.
59-
>
60-
> We are working hard to reduce the VRAM in RL training, LoRA support will be integrated in next updates.
61-
62-
## Tutorial: Run Qwen2.5-VL GRPO on [Geometry3K](https://huggingface.co/datasets/hiyouga/geometry3k) Dataset in Just 3 Steps
63-
64-
![image](assets/qwen2_5_vl_7b_geo.png)
65-
66-
### Installation
67-
68-
```bash
69-
git clone https://github.com/hiyouga/EasyR1.git
70-
cd EasyR1
71-
pip install -e .
72-
```
73-
74-
### GRPO Training
75-
76-
```bash
77-
bash examples/qwen2_5_vl_7b_geo3k_grpo.sh
78-
```
79-
80-
### Merge Checkpoint in Hugging Face Format
81-
82-
```bash
83-
python3 scripts/model_merger.py --local_dir checkpoints/easy_r1/exp_name/global_step_1/actor
84-
```
85-
86-
> [!TIP]
87-
> If you encounter issues with connecting to Hugging Face, consider using `export HF_ENDPOINT=https://hf-mirror.com`.
88-
>
89-
> If you want to use SwanLab logger, consider using `bash examples/qwen2_5_vl_7b_geo3k_swanlab.sh`.
90-
91-
## Custom Dataset
92-
93-
Please refer to the example datasets to prepare your own dataset.
94-
95-
- Text dataset: https://huggingface.co/datasets/hiyouga/math12k
96-
- Image-text dataset: https://huggingface.co/datasets/hiyouga/geometry3k
97-
- Multi-image-text dataset: https://huggingface.co/datasets/hiyouga/journeybench-multi-image-vqa
98-
99-
## How to Understand GRPO in EasyR1
100-
101-
![image](assets/easyr1_grpo.png)
102-
103-
- To learn about the GRPO algorithm, you can refer to [Hugging Face's blog](https://huggingface.co/docs/trl/v0.16.1/en/grpo_trainer).
104-
105-
## How to Run 70B+ Model in Multi-node Environment
106-
107-
1. Start the Ray head node.
108-
109-
```bash
110-
ray start --head --port=6379 --dashboard-host=0.0.0.0
111-
```
112-
113-
2. Start the Ray worker node and connect to the head node.
114-
115-
```bash
116-
ray start --address=<head_node_ip>:6379
14+
```shell
15+
python rag_serving/build_corpus.py
11716
```
11817

119-
3. Check the Ray resource pool.
18+
#### Image Index Building
12019

121-
```bash
122-
ray status
20+
```shell
21+
python index_builder.py \
    --retrieval_method vdr-2b-v1 \
    --model_path llamaindex/vdr-2b-v1 \
    --corpus_path /scratch-scc/projects/scc_ulsb_fe/yang/images_corpus/images.parquet \
    --save_dir /scratch-scc/projects/scc_ulsb_fe/yang/images_index \
    --max_length 512 \
    --batch_size 128 \
    --faiss_type Flat \
    --index_modal image \
    --sentence_transformer \
    --save_embedding
12322
```
12423

125-
4. Run training script on the Ray head node only.
126-
127-
```bash
128-
bash examples/qwen2_5_vl_7b_geo3k_grpo.sh
129-
```
130-
131-
See the **[veRL's official doc](https://verl.readthedocs.io/en/latest/start/multinode.html)** for more details about multi-node training and Ray debugger.
132-
133-
## Other Baselines
134-
135-
We also reproduced the following two baselines of the [R1-V](https://github.com/deep-agent/R1-V) project.
136-
- [CLEVR-70k-Counting](examples/baselines/qwen2_5_vl_3b_clevr.sh): Train the Qwen2.5-VL-3B-Instruct model on counting problem.
137-
- [GeoQA-8k](examples/baselines/qwen2_5_vl_3b_geoqa8k.sh): Train the Qwen2.5-VL-3B-Instruct model on GeoQA problem.
138-
139-
## Awesome Work using EasyR1
140-
141-
- **MMR1**: Advancing the Frontiers of Multimodal Reasoning. [![[code]](https://img.shields.io/github/stars/LengSicong/MMR1)](https://github.com/LengSicong/MMR1)
142-
- **Vision-R1**: Incentivizing Reasoning Capability in Multimodal Large Language Models. [![[code]](https://img.shields.io/github/stars/Osilly/Vision-R1)](https://github.com/Osilly/Vision-R1) [![[arxiv]](https://img.shields.io/badge/arxiv-2503.06749-blue)](https://arxiv.org/abs/2503.06749)
143-
- **Seg-Zero**: Reasoning-Chain Guided Segmentation via Cognitive Reinforcement. [![[code]](https://img.shields.io/github/stars/dvlab-research/Seg-Zero)](https://github.com/dvlab-research/Seg-Zero) [![[arxiv]](https://img.shields.io/badge/arxiv-2503.06520-blue)](https://arxiv.org/abs/2503.06520)
144-
- **MetaSpatial**: Reinforcing 3D Spatial Reasoning in VLMs for the Metaverse. [![[code]](https://img.shields.io/github/stars/PzySeere/MetaSpatial)](https://github.com/PzySeere/MetaSpatial) [![[arxiv]](https://img.shields.io/badge/arxiv-2503.18470-blue)](https://arxiv.org/abs/2503.18470)
145-
- **Temporal-R1**: Envolving Temporal Reasoning Capability into LMMs via Temporal Consistent Reward. [![[code]](https://img.shields.io/github/stars/appletea233/Temporal-R1)](https://github.com/appletea233/Temporal-R1)
146-
- **NoisyRollout**: Reinforcing Visual Reasoning with Data Augmentation. [![[code]](https://img.shields.io/github/stars/John-AI-Lab/NoisyRollout)](https://github.com/John-AI-Lab/NoisyRollout) [![[arxiv]](https://img.shields.io/badge/arxiv-2504.13055-blue)](https://arxiv.org/pdf/2504.13055)
147-
- **GUI-R1**: A Generalist R1-Style Vision-Language Action Model For GUI Agents. [![[code]](https://img.shields.io/github/stars/ritzz-ai/GUI-R1)](https://github.com/ritzz-ai/GUI-R1) [![[arxiv]](https://img.shields.io/badge/arxiv-2504.10458-blue)](https://arxiv.org/abs/2504.10458)
148-
149-
## TODO
150-
151-
- Support LoRA (high priority).
152-
- Support ulysses parallelism for VLMs (middle priority).
153-
- Support more VLM architectures.
154-
155-
> [!NOTE]
156-
> We will not provide scripts for supervised fine-tuning and inference in this project. If you have such requirements, we recommend using [LLaMA-Factory](https://github.com/hiyouga/LLaMA-Factory).
157-
158-
### Known bugs
159-
160-
These features are temporarily disabled for now, we plan to fix them one-by-one in the future updates.
161-
162-
- Vision language models are not compatible with ulysses parallelism yet.
24+
### Launch RL
16325

164-
## Discussion Group
26+
#### Tool Environment Serving
16527

166-
👋 Join our [WeChat group](assets/wechat.jpg).
28+
1. Get the IP address of the server
16729

168-
## FAQs
30+
```shell
31+
hostname --ip-address
32+
```
16933

170-
> ValueError: Image features and image tokens do not match: tokens: 8192, features 9800
34+
2. Start serving
17135

172-
Increase the `data.max_prompt_length` or reduce the `data.max_pixels`.
36+
```shell
37+
python rag_serving/serving.py --config rag_serving/serving_config.yaml --num_retriever 4 --port 42354
38+
```
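The inference script in this commit sets `SEARCH_URL=http://10.241.148.102:42354`, which appears to combine the server address from step 1 with the `--port` value from step 2. A minimal sketch of that composition, assuming the serving endpoint is plain HTTP on the given port (the helper name `build_search_url` is hypothetical and not part of the repository):

```shell
# Hypothetical helper: join a server IP and a port into the URL form
# used by SEARCH_URL in the generation script.
build_search_url() {
    # $1: IP address of the serving machine, $2: serving port
    printf 'http://%s:%s\n' "$1" "$2"
}

# Example (commented out so the snippet has no side effects):
# SEARCH_URL=$(build_search_url "$(hostname --ip-address | awk '{print $1}')" 42354)
```

The `awk '{print $1}'` in the example keeps only the first address when `hostname --ip-address` prints several; whether the first one is the reachable address depends on the cluster network.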
17339

174-
> RuntimeError: CUDA Error: out of memory at /workspace/csrc/cumem_allocator.cpp:62
40+
#### RL Training
17541

176-
Reduce the `worker.rollout.gpu_memory_utilization` and enable `worker.actor.offload.offload_params`.
17742

178-
> RuntimeError: 0 active drivers ([]). There should only be one.
17943

180-
Uninstall `deepspeed` from the current python environment.
44+
## General VQA
Lines changed: 1 addition & 1 deletion
@@ -409,7 +409,7 @@ class vLLMRolloutAgent(vLLMRollout, ImageProcessMixin):
409409
sampling_params=self.sampling_params,
410410
use_tqdm=False
411411
)
412-
pydevd_pycharm.settrace('47.76.117.131', port=47508, stdoutToServer=True, stderrToServer=True)
412+
# pydevd_pycharm.settrace('47.76.117.131', port=47508, stdoutToServer=True, stderrToServer=True)
413413
search_queries = []
414414
search_indices = []
415415
search_doc_ids = []
Lines changed: 116 additions & 0 deletions
@@ -0,0 +1,116 @@
1+
#!/bin/bash
2+
#SBATCH --job-name=EasyR1-qwen2p5VL-7b-DocAgent
3+
#SBATCH --nodes=2
4+
#SBATCH --mem=450G
5+
#SBATCH --mail-user=tianyu.yang@uni-goettingen.de
6+
#SBATCH --mail-type=all
7+
#SBATCH --cpus-per-task=64
8+
#SBATCH -p kisski
9+
#SBATCH --gpus-per-node=4
10+
#SBATCH -t 48:00:00
11+
#SBATCH --output=slurm-%j.out
12+
#SBATCH --error=slurm-%j.err
13+
#############module load cuda/12.2.1
14+
############SBATCH --constraint=80gb
15+
################SBATCH --mem=500G
16+
17+
set -x
18+
#export VLLM_ATTENTION_BACKEND=XFORMERS
19+
20+
MODEL_PATH=checkpoints/EasyR1/global_step_355/actor/huggingface # replace it with your local file path
21+
WANDB_API_KEY=a3b3f7b7962a8b549c4635ee3a03944d554f1a10
22+
ROLLOUT_NAME=vllm_agent
23+
SEARCH_TOP_N=1
24+
SEARCH_URL=http://10.241.148.102:42354
25+
LIMIT_IMAGES=15
26+
MAX_RESPONSE_LENGTH=15000
27+
MAX_PROMPT_LENGTH=1024
28+
ROLLOUT_MAX_NUM_BATCHED_TOKENS=16024
29+
TENSOR_PARALLEL_SIZE=2
30+
PROMPT_KEY=question
31+
ROLLOUT_BATCH_SIZE=128
32+
ROLLOUT_N=1
33+
VAL_BATCH_SIZE=-1
34+
TEMPERATURE=0.2
35+
TEST_DATA_PATH=/mnt/vast-kisski/projects/kisski-sub-doc-understanding/EasyR1/dataset/test/feta.parquet
36+
37+
CONFIG_PATH=/mnt/vast-kisski/projects/kisski-sub-doc-understanding/EasyR1/examples/generation_config.yaml
38+
SAVE_PATH=/mnt/vast-kisski/projects/kisski-sub-doc-understanding/EasyR1/generation_results/qwen2_5_vl_7b_doc_agent_test
39+
40+
if [ "$WANDB_API_KEY" != "None" ]; then
41+
wandb login --relogin "$WANDB_API_KEY"
42+
fi
43+
44+
# make output directory
45+
if [ ! -d "$SAVE_PATH" ]; then
46+
mkdir -p "$SAVE_PATH"
47+
fi
48+
49+
nodes=$(scontrol show hostnames "$SLURM_JOB_NODELIST")
50+
nodes_array=($(scontrol show hostnames "$SLURM_JOB_NODELIST"))
51+
52+
head_node=${nodes_array[0]}
53+
head_node_ip=$(srun --nodes=1 --ntasks=1 -w "$head_node" hostname --ip-address)
54+
55+
if [[ "$head_node_ip" == *" "* ]]; then
56+
IFS=' ' read -ra ADDR <<<"$head_node_ip"
57+
if [[ ${#ADDR[0]} -gt 16 ]]; then
58+
head_node_ip=${ADDR[1]}
59+
else
60+
head_node_ip=${ADDR[0]}
61+
fi
62+
echo "IPV6 address detected. We split the IPV4 address as $head_node_ip"
63+
fi
64+
65+
port=6379
66+
ip_head=$head_node_ip:$port
67+
export ip_head
68+
echo "IP Head: $ip_head"
69+
70+
71+
echo "Starting HEAD at $head_node"
72+
srun --nodes=1 --ntasks=1 -w "$head_node" /bin/bash -c \
73+
"source /user/yang28/u14705/.bashrc && source /mnt/vast-kisski/projects/kisski-sub-doc-understanding/miniconda3/bin/activate EasyR1 \
74+
&& ray start --head --node-ip-address="$head_node_ip" --port=$port \
75+
--num-cpus "${SLURM_CPUS_PER_TASK}" --num-gpus "${SLURM_GPUS_PER_NODE}" --include-dashboard true --dashboard-host 0.0.0.0 --dashboard-port 8265 --block" &
76+
# optional, though may be useful in certain versions of Ray < 1.0.
77+
sleep 10
78+
79+
# number of nodes other than the head node
80+
worker_num=$((SLURM_JOB_NUM_NODES - 1))
81+
#export worker_num = 1
82+
83+
for ((i = 1; i <= worker_num; i++)); do
84+
node_i=${nodes_array[$i]}
85+
echo "Starting WORKER $i at $node_i"
86+
srun --nodes=1 --ntasks=1 -w "$node_i" /bin/bash -c \
87+
"source /user/yang28/u14705/.bashrc && source /mnt/vast-kisski/projects/kisski-sub-doc-understanding/miniconda3/bin/activate EasyR1 \
88+
&& ray start --address "$ip_head" --num-cpus "${SLURM_CPUS_PER_TASK}" --num-gpus "${SLURM_GPUS_PER_NODE}" --block" &
89+
sleep 5
90+
done
91+
92+
93+
srun --overlap --nodes=1 --ntasks=1 -w "$head_node" /bin/bash -c \
94+
"source /user/yang28/u14705/.bashrc && source /mnt/vast-kisski/projects/kisski-sub-doc-understanding/miniconda3/bin/activate EasyR1 \
95+
&& python -m verl.trainer.main_generation \
96+
config=${CONFIG_PATH} \
97+
data.test_files=${TEST_DATA_PATH} \
98+
data.prompt_key=${PROMPT_KEY} \
99+
data.format_prompt=./examples/format_prompt/doc_agent.py \
100+
data.max_response_length=${MAX_RESPONSE_LENGTH} \
101+
data.max_prompt_length=${MAX_PROMPT_LENGTH} \
102+
data.rollout_batch_size=${ROLLOUT_BATCH_SIZE} \
103+
worker.actor.model.model_path=${MODEL_PATH} \
104+
worker.rollout.tensor_parallel_size=${TENSOR_PARALLEL_SIZE} \
105+
worker.rollout.name=${ROLLOUT_NAME} \
106+
worker.rollout.n=${ROLLOUT_N} \
107+
worker.rollout.temperature=${TEMPERATURE} \
108+
worker.rollout.max_num_batched_tokens=${ROLLOUT_MAX_NUM_BATCHED_TOKENS} \
109+
worker.rollout.top_n=${SEARCH_TOP_N} \
110+
worker.rollout.search_url=${SEARCH_URL} \
111+
worker.rollout.limit_images=${LIMIT_IMAGES} \
112+
worker.reward.score_function=./examples/score_function/doc_agent.py:compute_score \
113+
trainer.n_gpus_per_node=${SLURM_GPUS_PER_NODE} \
114+
trainer.nnodes=${SLURM_NNODES} \
115+
trainer.save_checkpoint_path=${SAVE_PATH}"
116+
# trainer.load_checkpoint_path=/mnt/vast-kisski/projects/kisski-sub-doc-understanding/EasyR1/checkpoints/qwen2_5_vl_7b_doc_agent/global_step_160"
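The script above chooses the head node address by checking whether the first entry returned by `hostname --ip-address` is longer than 16 characters. As an alternative sketch (not the author's method), the same selection can be written as a pattern match that skips any entry containing a colon and returns the first dotted one. The function name `pick_ipv4` is hypothetical:

```shell
# Hypothetical helper: print the first IPv4-looking entry from a
# space-separated address list, such as the output of `hostname --ip-address`.
pick_ipv4() {
    # $1 is left unquoted on purpose so the shell splits it on whitespace.
    for addr in $1; do
        case "$addr" in
            *:*) ;;                              # contains a colon: treat as IPv6 and skip
            *.*.*.*) echo "$addr"; return 0 ;;   # four dot-separated parts: accept
        esac
    done
    return 1                                     # no IPv4-looking entry found
}

# Example (commented out because it needs a SLURM allocation):
# head_node_ip=$(pick_ipv4 "$(srun --nodes=1 --ntasks=1 -w "$head_node" hostname --ip-address)")
```

This does not validate octet ranges; it only distinguishes dotted entries from colon-separated ones, which is the distinction the original length check relies on.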
Lines changed: 1 addition & 1 deletion
@@ -1 +1 @@
1-
system_prompt = "You are a helpful assistant designed to answer user questions based on a user-provided multi-page document. The document can not be input directly with the question, you must reason step by step to determine how to obtain evidence document pages by optimally utilizing tools and analyze the relevant content in the obtained document pages to precisely answer user's question. Your reasoning process MUST BE enclosed within <think> </think> tags. Your answer MUST BE enclosed within <answer> </answer> tags. In the last part of the answer, the final exact answer is enclosed within \\boxed{{}} with latex format. The available tools include a **search tool** and a **fetch tool**. During reasoning, you can invoke either the search tool by generating <search> your search query here </search> to retrieve document pages most relevant to your search query or or the fetch tool by generating <fetch> page number </fetch> to obtain a specific document page. For example, your response could be in the format of \'<think> your reasoning process </think> <search> search query </search>\', or \'<think> your reasoning process </think> <fetch> page number </fetch>\', or \'<think> your reasoning process </think> <answer> your answer here. The final answer is \\[ \\boxed{{answer here}} \\] </answer>\'. After invoking a tool, the user will return obtained document pages inside <result> </result> tags to you.\n\n**Important constraints**:\n- If there is no answer found in the document, respond with <answer> The final answer is \\[ \\boxed{{The problem is not answerable}} \\] </answer>.\n- If multiple valid answers are found, return them separated by semicolons.\n- Only one page can be fetched at a time using the fetch tool.\n- Enrich the user question to form a good search query to get more accurate retrieval results.\n- Do not naively use the fetch tool if you don't know the specific page number of the document page that the user is asking about."
1+
system_prompt = "You are a helpful assistant designed to answer user questions based on a user-provided multi-page document. The document can not be input directly with the question, you must reason step by step to determine how to obtain evidence document pages by optimally utilizing tools and analyze the relevant content in the obtained document pages to precisely answer user's question. Your reasoning process MUST BE enclosed within <think> </think> tags. Your answer MUST BE enclosed within <answer> </answer> tags. In the last part of the answer, the final exact answer is enclosed within \\boxed{{}} with latex format. The available tools include a **search tool** and a **fetch tool**. During reasoning, you can invoke either the search tool by generating <search> your search query here </search> to retrieve document pages most relevant to your search query or or the fetch tool by generating <fetch> page number </fetch> to obtain a specific document page. For example, your response could be in the format of \'<think> your reasoning process </think> <search> search query </search>\', or \'<think> your reasoning process </think> <fetch> page number </fetch>\', or \'<think> your reasoning process </think> <answer> your answer here. The final answer is \\[ \\boxed{{answer here}} \\] </answer>\'. After invoking a tool, the user will return obtained document pages inside <result> </result> tags to you.\n\n**Important constraints**:\n- If there is no answer found in the document, respond with <answer> The final answer is \\[ \\boxed{{The problem is not answerable}} \\] </answer>.\n- If multiple valid answers are found, return them separated by semicolons.\n- Only one page can be fetched at a time using the fetch tool.\n- Enrich the user question to form a good search query to get more accurate retrieval results.\n- Do not naively use the fetch tool if you don't know the specific page number of the document page that the user is asking about."
