Commit 2530467

committed: update

1 parent da9e076 commit 2530467

25 files changed
Lines changed: 1247 additions & 104 deletions

examples/OmniGen2-RL/README.md

Lines changed: 43 additions & 15 deletions

@@ -79,7 +79,11 @@ Download the official RL training data from [EditScore-RL-Data](https://huggingf
 2. Create Meta File
 The uploaded dataset uses relative image paths. Run the following script to convert them to absolute paths based on your local environment:
 ```bash
+# Then
 python scripts/data/process_jsonl.py --input /path/to/EditScore-RL-Data/rl.jsonl --output /path/to/EditScore-RL-Data/rl_abs.jsonl --base-path /path/to/EditScore-RL-Data
+
+# Due to the limitations of the base model (OmniGen2), we discard text change and portrait beautification, as these tasks harm RL training.
+python scripts/data/extract_9_tasks.py --input_path /path/to/EditScore-RL-Data/rl_abs.jsonl --output_path /path/to/EditScore-RL-Data/rl_abs_9tasks.jsonl
 ```
 3. Configure the Data Path
 Specify the path to your processed `.jsonl` file in the data configuration located at `data_configs/train/example/edit/all.yml`.
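The meta-file step above can be sketched in a few lines. This is a hedged illustration of what `scripts/data/process_jsonl.py` accomplishes, not its actual code; the record field names (`input_images`, `output_image`) are assumptions for illustration only.

```python
# Illustrative sketch: rewrite relative image paths in a .jsonl file to
# absolute paths rooted at --base-path. Field names are assumptions.
import json
import os


def to_absolute(record: dict, base_path: str) -> dict:
    """Return a copy of one jsonl record with image paths made absolute."""
    out = dict(record)
    if "input_images" in out:
        out["input_images"] = [os.path.join(base_path, p) for p in out["input_images"]]
    if "output_image" in out:
        out["output_image"] = os.path.join(base_path, out["output_image"])
    return out


def convert_jsonl(input_path: str, output_path: str, base_path: str) -> None:
    """Stream the input jsonl and write the path-rewritten records."""
    with open(input_path) as fin, open(output_path, "w") as fout:
        for line in fin:
            fout.write(json.dumps(to_absolute(json.loads(line), base_path)) + "\n")
```

`os.path.join` leaves already-absolute paths untouched, so rerunning the conversion on converted data is harmless.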
@@ -89,7 +93,7 @@ ratio_type: inside_ratio
 
 data:
   -
-    path: '/path/to/EditScore-RL-Data/rl_abs.jsonl' # <-- Ensure this path is correct
+    path: '/path/to/EditScore-RL-Data/rl_abs_9tasks.jsonl' # <-- Ensure this path is correct
     type: 'edit'
     ratio: !!float 1
 ```
@@ -130,29 +134,53 @@ python reward_server/scripts/utils/reward_server_sanity_check.py --config_path=r
 ```
 Once these steps are complete, your environment is ready to begin the reinforcement learning fine-tuning process.
 
-### 4. Start Training
+### 4. Start RL Fine-Tuning
 
 **Configure Training Parameters**
+Before launching, you may need to adjust key parameters in the configuration file: `options/omnigen2_edit_rl_4machine_editscore7b_avg4.yml`.
 
-Edit the `options/omnigen2_edit_rl.yml` configuration file, focusing on these key parameters:
-- `train.global_batch_size`: Global batch size across all GPUs (num_unique_prompts_per_sampling * num_images_per_prompt)
-- `train.batch_size`: Batch size per GPU (batch_size_per_forward * gradient_accumulation_steps * num_update_steps_per_sampling)
-- `train.rl.num_images_per_prompt`: The number of roolout of one prompt
-- `train.rl.num_unique_prompts_per_sampling`: Number of global unique prompts
+Here are some important settings:
+- `train.global_batch_size`: The total number of images generated across all GPUs in a single sampling phase before the policy is updated. It is calculated as `num_unique_prompts_per_sampling * num_images_per_prompt`.
+- `train.batch_size`: Batch size per GPU (`batch_size_per_forward * gradient_accumulation_steps * num_update_steps_per_sampling`).
+- `train.rl.num_images_per_prompt`: The number of candidate images to generate for each unique prompt.
+- `train.rl.num_unique_prompts_per_sampling`: The number of unique prompts in a global batch.
+- `train.rl.num_update_steps_per_sampling`: The number of gradient updates to perform in each sampling phase. Set this to `> 1` to enable off-policy RL, which improves sample efficiency.
+- `train.rl.batch_size_per_forward`: Batch size for each forward pass. Together with `num_update_steps_per_sampling`, it defines the total number of samples processed per policy update.
 
 **Launch Distributed Training**
+We provide scripts for both single-machine and multi-machine distributed training based on **FSDP**.
 ```bash
-# Single machine training (8*H100 GPUs)
-bash scripts/train/omnigen2_edit_rl.sh
+# Single-machine training (8 GPUs) using EditScore-7B as the reward model
+bash scripts/train/omnigen2_edit_rl_single_machine_editscore7b.sh
 
-# Multi-machine distributed training
+# Multi-machine training (e.g., 4 machines with 8 GPUs each) using EditScore-7B (Avg@4)
+bash scripts/train/omnigen2_edit_rl_4machine_editscore7b_avg4.sh
 ```
 
-> **⚠️ Training Configuration Key Points**
->
-> **Reward Server IP**: Ensure the `REWARD_SERVER_IP` environment variable in training scripts points to the correct reward server address
+### 5. Training Outputs and Monitoring
+All training artifacts, including logs and model checkpoints, are saved to the `experiments/` directory.
 
+### 6. Evaluate Your RL Fine-Tuned Model
+After training, you must convert the FSDP-saved checkpoint (`.bin`) into the standard Hugging Face format before you can use it for inference.
 
-### 4. Training Outputs and Monitoring
+#### Step 1: Convert the Checkpoint
+We provide a script that handles the conversion from the distributed FSDP format to the standard Hugging Face format (`.bin`).
+Run the following command, replacing the arguments with your experiment's details:
 
-Logs and saved model checkpoints in `experiments/`
+```shell
+bash scripts/misc/convert_dist_ckpt_to_hf_format.sh [EXPERIMENT_NAME] [STEP_NUMBER]
+```
+- `[EXPERIMENT_NAME]`: The name of your training experiment (e.g., omnigen2_edit_rl_single_machine_editscore7b).
+- `[STEP_NUMBER]`: The training step of the checkpoint you wish to evaluate (e.g., 500).
+
+This will create a new directory containing the converted model weights in the standard format, ready for inference.
+
+#### Step 2: Run Evaluation on GEdit-Bench
+Once the checkpoint is converted, you can benchmark its performance. We provide evaluation scripts tailored for GEdit-Bench.
+
+You can use our example scripts as a template: copy one and modify the internal paths to point to your newly converted model checkpoint.
+```shell
+# Run evaluation for the converted model from step 500
+bash evaluation/GEdit-Bench/omnigen2_edit_rl_single_machine_editscore7b_step500.sh
+```
+By comparing the results to the baseline model's performance, you can quantify the improvements achieved through RL fine-tuning with EditScore.
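The batch-size arithmetic described in the settings above can be sanity-checked in a few lines of Python, using the values from the avg4 configuration (48 unique prompts, 12 rollouts per prompt):

```python
# Verify the batch-size relationships from the RL config documentation.
num_unique_prompts_per_sampling = 48
num_images_per_prompt = 12
batch_size_per_forward = 9
gradient_accumulation_steps = 1
num_update_steps_per_sampling = 2

# Global batch: images generated per sampling phase across all GPUs.
global_batch_size = num_unique_prompts_per_sampling * num_images_per_prompt

# Per-GPU batch: samples each GPU processes across all updates in a phase.
per_gpu_batch_size = (batch_size_per_forward
                      * gradient_accumulation_steps
                      * num_update_steps_per_sampling)

print(global_batch_size, per_gpu_batch_size)  # 576 18
```

These match `train.global_batch_size: 576` and `train.batch_size: 18` in the shipped configuration, so changing any factor (e.g., the rollout count) requires updating the corresponding batch-size field.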

examples/OmniGen2-RL/data_configs/train/example/edit/all.yml

Lines changed: 2 additions & 1 deletion

@@ -2,6 +2,7 @@ ratio_type: inside_ratio
 
 data:
   -
-    path: '/path/to/EditScore-RL-Data/rl_abs.jsonl'
+    # path: '/path/to/EditScore-RL-Data/rl_abs_9tasks.jsonl'
+    path: '/share/project/chenyuan/data2/EditScore-RL-Data-v4/rl_abs_9tasks.jsonl'
     type: 'edit'
     ratio: !!float 1
Lines changed: 94 additions & 0 deletions

@@ -0,0 +1,94 @@
+#!/bin/bash
+SHELL_FOLDER=$(cd "$(dirname "$0")"; pwd)
+cd $(dirname $SHELL_FOLDER)
+cd ../
+
+source "$(dirname $(which conda))/../etc/profile.d/conda.sh"
+conda activate py3.12+pytorch2.7.1+cu126
+
+RANK=0
+MASTER_ADDR=1
+MASTER_PORT=29500
+WORLD_SIZE=1
+
+# Parse named arguments
+while [[ $# -gt 0 ]]; do
+    case "$1" in
+        --rank=*)
+            RANK="${1#*=}"
+            shift
+            ;;
+        --master_addr=*)
+            MASTER_ADDR="${1#*=}"
+            shift
+            ;;
+        --master_port=*)
+            MASTER_PORT="${1#*=}"
+            shift
+            ;;
+        --world_size=*)
+            WORLD_SIZE="${1#*=}"
+            shift
+            ;;
+        *)
+            echo "Unknown argument: $1"
+            shift
+            ;;
+    esac
+done
+
+# Print configuration
+echo "RANK: $RANK"
+echo "MASTER_ADDR: $MASTER_ADDR"
+echo "MASTER_PORT: $MASTER_PORT"
+echo "WORLD_SIZE: $WORLD_SIZE"
+
+global_shift_index=0
+total_num_images=606
+
+num_gpus_per_machine=$(python -c "import torch; print(torch.cuda.device_count())")
+# Calculate images per machine, rounding up to ensure all data is covered
+num_images_per_machine=$(( (total_num_images + WORLD_SIZE - 1) / WORLD_SIZE ))
+shift_index=$((RANK * num_images_per_machine))
+
+if [ $((total_num_images - shift_index)) -lt $num_images_per_machine ]; then
+    num_images_per_machine=$((total_num_images - shift_index))
+fi
+
+# Calculate the per-GPU share (every GPU except the last processes this amount)
+num_images_per_gpu=$(( (num_images_per_machine + num_gpus_per_machine - 1) / num_gpus_per_machine ))
+
+text_guidance_scale=5.0
+image_guidance_scale=1.5
+
+for ((i=0; i<num_gpus_per_machine; i++)); do
+    if [ $i -lt $((num_gpus_per_machine - 1)) ]; then
+        # All but the last GPU process equal amounts
+        start_idx=$((global_shift_index + i * num_images_per_gpu + shift_index))
+        end_idx=$((start_idx + num_images_per_gpu))
+    else
+        # The last GPU processes the remaining data
+        start_idx=$((global_shift_index + (num_gpus_per_machine - 1) * num_images_per_gpu + shift_index))
+        end_idx=$((global_shift_index + shift_index + num_images_per_machine))
+    fi
+    echo ${start_idx} ${end_idx}
+
+    CUDA_VISIBLE_DEVICES=${i} WORLD_SIZE=1 nohup accelerate launch --num_processes 1 --num_machines 1 \
+        evaluation/GEdit-Bench/inference.py \
+        --load_from_pipeline \
+        --pipeline_path OmniGen2/OmniGen2 \
+        --transformer_lora_path experiments/omnigen2_edit_rl_single_machine_editscore7b/checkpoint-500/transformer_lora \
+        --num_inference_step 50 \
+        --height 1024 \
+        --width 1024 \
+        --text_guidance_scale ${text_guidance_scale} \
+        --image_guidance_scale ${image_guidance_scale} \
+        --time_shift_base_res 168 \
+        --negative_prompt "" \
+        --use_ori_neg_prompt_template \
+        --scheduler "euler" \
+        --result_dir evaluation/GEdit-Bench/results/OmniGen2/results_ts${text_guidance_scale}_ig${image_guidance_scale}_16samples \
+        --start_index ${start_idx} --end_index ${end_idx} \
+        --num_samples 16 \
+        > logs/gedit_OmniGen2_ts${text_guidance_scale}_ig${image_guidance_scale}_16samples_${start_idx}_${end_idx}.log 2>&1 &
done
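The shell arithmetic in the script above can be mirrored in Python to check that the ceil-division sharding covers every image index exactly once. This is an illustrative sketch of the same scheme, not code from the repository:

```python
# Mirror the script's sharding: ceil-divide the images across machines,
# then across GPUs, with the last GPU picking up the remainder.
def shard_indices(total_images, world_size, rank, num_gpus):
    per_machine = (total_images + world_size - 1) // world_size  # ceil division
    shift = rank * per_machine
    per_machine = min(per_machine, total_images - shift)         # clamp last machine
    per_gpu = (per_machine + num_gpus - 1) // num_gpus           # ceil division
    ranges = []
    for i in range(num_gpus):
        start = shift + i * per_gpu
        # The last GPU takes everything that remains on this machine.
        end = shift + per_machine if i == num_gpus - 1 else start + per_gpu
        ranges.append((start, end))
    return ranges


# With the script's defaults (606 images, 1 machine, 8 GPUs), the first
# 7 GPUs each get 76 images and the last GPU gets the remaining 74.
print(shard_indices(606, 1, 0, 8))
```

Each `(start, end)` pair corresponds to the `--start_index`/`--end_index` arguments passed to `inference.py`, so adjacent ranges must be contiguous and jointly cover `[0, total_num_images)`.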

examples/OmniGen2-RL/nccl_logs/.gitkeep

Whitespace-only changes.

examples/OmniGen2-RL/options/omnigen2_edit_rl.yml renamed to examples/OmniGen2-RL/options/omnigen2_edit_rl_4machine_editscore7b_avg4.yml

Lines changed: 8 additions & 10 deletions

@@ -1,9 +1,11 @@
-name: omnigen2_edit_rl
+name: omnigen2_edit_rl_4machine_editscore7b_avg4
 
 seed: 2233
 device_specific_seed: true
 workder_specific_seed: true
 
+reward_server_config: reward_server/server_configs/editscore_7B_avg4.yml
+
 data:
   data_path: data_configs/train/example/train.yml
   use_chat_template: true

@@ -40,14 +42,11 @@ transport:
   dynamic_time_shift: true
 
 train:
-
-  # global_batch_size: 576
-  global_batch_size: 144
+  global_batch_size: 576
   batch_size: 18
   gradient_accumulation_steps: 1
-
-  # num_train_epochs: 4
-  max_train_steps: 5000
+
+  max_train_steps: 1000
 
   dataloader_num_workers: 12

@@ -85,8 +84,7 @@ train:
   lora_dropout: 0
 
   rl:
-    # num_unique_prompts_per_sampling: 48
-    num_unique_prompts_per_sampling: 12
+    num_unique_prompts_per_sampling: 48
     num_update_steps_per_sampling: 2
     batch_size_per_forward: 9
     num_images_per_prompt: 12

@@ -107,12 +105,12 @@ train:
     server_type: vlm
     use_ori_neg_prompt_template: true
     time_shift_base_res: 168
+    policy_loss_reweighting: false
 
 val:
   train_visualization_interval: 5
   num_train_visualization_samples: 3
 
-
 logger:
   log_with: [wandb, tensorboard]
   # log_with: ~
Lines changed: 122 additions & 0 deletions

@@ -0,0 +1,122 @@
+name: omnigen2_edit_rl_4machine_editscore7b_avg8
+
+seed: 2233
+device_specific_seed: true
+workder_specific_seed: true
+
+reward_server_config: reward_server/server_configs/editscore_7B_avg8.yml
+
+data:
+  data_path: data_configs/train/example/train.yml
+  use_chat_template: true
+  maximum_text_tokens: 888
+  prompt_dropout_prob: !!float 0.0
+  ref_img_dropout_prob: !!float 0.0
+  max_output_pixels: 262144 # 512 * 512
+  max_input_pixels: [262144, 262144, 262144, 262144] # [512 * 512, 512 * 512, 512 * 512, 512 * 512]
+  max_side_length: 2048
+
+model:
+  pretrained_vae_model_name_or_path: black-forest-labs/FLUX.1-dev
+  pretrained_text_encoder_model_name_or_path: Qwen/Qwen2.5-VL-3B-Instruct
+  pretrained_model_path: pretrained_models/OmniGen2/transformer/pytorch_model.bin
+
+  arch_opt:
+    patch_size: 2
+    in_channels: 16
+    hidden_size: 2520
+    num_layers: 32
+    num_refiner_layers: 2
+    num_attention_heads: 21
+    num_kv_heads: 7
+    multiple_of: 256
+    norm_eps: !!float 1e-05
+    axes_dim_rope: [40, 40, 40]
+    axes_lens: [10000, 10000, 10000]
+    text_feat_dim: 2048
+    timestep_scale: !!float 1000
+
+transport:
+  snr_type: lognorm
+  do_shift: true
+  dynamic_time_shift: true
+
+train:
+  global_batch_size: 576
+  batch_size: 18
+  gradient_accumulation_steps: 1
+
+  max_train_steps: 1000
+
+  dataloader_num_workers: 12
+
+  # Optimizer
+  learning_rate: !!float 4e-4
+  scale_lr: false
+  lr_scheduler: timm_constant_with_warmup
+  warmup_t: 0
+  warmup_lr_init: 1e-7
+  warmup_prefix: true
+  t_in_epochs: false
+
+  # resume_from_checkpoint:
+
+  use_8bit_adam: false
+  adam_beta1: 0.9
+  adam_beta2: 0.95
+  adam_weight_decay: !!float 0.01
+  adam_epsilon: !!float 1e-08
+  max_grad_norm: 1
+
+  gradient_checkpointing: true
+
+  set_grads_to_none: true
+
+  # Misc
+  allow_tf32: false
+  mixed_precision: 'bf16'
+
+  ema_decay: 0.0
+
+  lora_ft: true
+  lora_rank: 32
+  lora_alpha: 64
+  lora_dropout: 0
+
+  rl:
+    num_unique_prompts_per_sampling: 48
+    num_update_steps_per_sampling: 2
+    batch_size_per_forward: 9
+    num_images_per_prompt: 12
+    sigma_coef: 0.7
+    negative_prompt: ""
+    num_inference_step: 20
+    max_sequence_length: 1024
+    text_guidance_scale: 4
+    image_guidance_scale: 2
+    cfg_range_start: 0.0
+    cfg_range_end: 0.6
+    train_timesteps_fraction: 0.6
+    reuse_samples_nums: 1
+    clip_range: [!!float 1e-4, !!float 5e-4]
+    adv_clip_max: !!float 5
+    kl_loss_weight: !!float 0.04
+    apply_cfg_in_training: true
+    server_type: vlm
+    use_ori_neg_prompt_template: true
+    time_shift_base_res: 168
+    policy_loss_reweighting: false
+
+val:
+  train_visualization_interval: 5
+  num_train_visualization_samples: 3
+
+logger:
+  log_with: [wandb, tensorboard]
+  # log_with: ~
+
+  checkpointing_steps: 50
+  checkpoints_total_limit: ~
+
+cache_dir:
+resume_from_checkpoint: latest
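Both configurations above fine-tune a LoRA adapter with `lora_rank: 32` and `lora_alpha: 64`. In standard LoRA implementations (e.g., Hugging Face PEFT), the adapter update is scaled by `alpha / rank`, so these values imply a fixed scaling factor of 2; raising the rank without adjusting alpha would shrink the effective update. A one-line check:

```python
# Effective LoRA scaling implied by the config values above,
# assuming the standard alpha / rank convention.
lora_rank = 32
lora_alpha = 64
lora_scaling = lora_alpha / lora_rank
print(lora_scaling)  # 2.0
```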
