Once these steps are complete, your environment is ready to begin the reinforcement learning fine-tuning process.
### 4. Start RL Fine-Tuning
**Configure Training Parameters**
Before launching, you may need to adjust key parameters in the configuration file: `options/omnigen2_edit_rl_4machine_editscore7b_avg4.yml`.
Here are some important settings:
- `train.global_batch_size`: The total number of images generated across all GPUs in a single sampling phase before the policy is updated. It is calculated as `num_unique_prompts_per_sampling * num_images_per_prompt`.
- `train.batch_size`: Batch size per GPU, calculated as `batch_size_per_forward * gradient_accumulation_steps * num_update_steps_per_sampling`.
- `train.rl.num_images_per_prompt`: The number of candidate images to generate for each unique prompt.
- `train.rl.num_unique_prompts_per_sampling`: The number of unique prompts in a global batch.
- `train.rl.num_update_steps_per_sampling`: The number of gradient updates to perform in each sampling phase. Set this to `> 1` to enable off-policy RL, which improves sample efficiency.
- `train.rl.batch_size_per_forward`: Batch size for each forward pass. Together with `num_update_steps_per_sampling`, it defines the total number of samples processed per policy update.
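To make the arithmetic between these settings concrete, here is a minimal Python sketch using hypothetical values; the real numbers live in the YAML config, and the even-split consistency check across GPUs is an assumption, not something the config enforces.

```python
# Hypothetical values for illustration only; the actual settings live in the
# YAML config (e.g. options/omnigen2_edit_rl_4machine_editscore7b_avg4.yml).
num_unique_prompts_per_sampling = 16  # unique prompts per sampling phase
num_images_per_prompt = 4             # rollouts generated per prompt
batch_size_per_forward = 2            # samples per forward pass on one GPU
gradient_accumulation_steps = 2       # forward passes per gradient update
num_update_steps_per_sampling = 2     # > 1 enables off-policy updates

# Total images generated across all GPUs in one sampling phase
global_batch_size = num_unique_prompts_per_sampling * num_images_per_prompt

# Samples each GPU processes per sampling phase
batch_size = (batch_size_per_forward
              * gradient_accumulation_steps
              * num_update_steps_per_sampling)

# Assumed consistency check: if the sampled images are split evenly across
# GPUs, the per-GPU batch times the GPU count should cover the global batch.
num_gpus = 8  # hypothetical single-machine setup
assert global_batch_size == batch_size * num_gpus  # 64 == 8 * 8
```

With these example values, every image sampled in a phase is consumed exactly once per policy update; scaling any one setting requires rebalancing the others to keep the two sides equal.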
**Launch Distributed Training**
We provide scripts for both single and multi-machine distributed training based on **FSDP**.
```bash
# Single-machine training (8 GPUs) using EditScore-7B as the reward model
bash scripts/train/omnigen2_edit_rl.sh
```