Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions configs/train_waa_vagen.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ model:
# Each env instance connects to the WAA server independently via HTTP.
envs:
- name: WAADesktop
n_envs: 8 # Number of parallel environments (= GRPO group size)
n_envs: 1 # Must be 1 (single WAA VM; use rollout.n for GRPO group size)
data_source: waa
seed: [1, 100, 1] # [start, end, step] for deterministic seeding
max_turns: 15 # Max actions per episode
Expand All @@ -72,7 +72,7 @@ algorithm:

trainer:
total_epochs: 100
n_gpus_per_node: 2 # Minimum for VLM training
n_gpus_per_node: 1 # 1 for g5.xlarge; use 4 for g5.12xlarge
micro_batch_size: 4
gradient_accumulation_steps: 2
test_freq: 5 # Evaluate every N epochs
Expand Down
4 changes: 2 additions & 2 deletions openadapt_evals/benchmarks/vm_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -9122,8 +9122,8 @@ def main():
help="Model to train",
)
p_gpu_train.add_argument(
"--n-gpus", type=int, default=2,
help="Number of GPUs (default: 2)",
"--n-gpus", type=int, default=1,
help="Number of GPUs (default: 1, use 4 for g5.12xlarge)",
)
p_gpu_train.add_argument(
"--epochs", type=int, default=100,
Expand Down
84 changes: 37 additions & 47 deletions scripts/train_verl_e2e.py
Original file line number Diff line number Diff line change
Expand Up @@ -287,13 +287,10 @@ def _generate_training_config(
import json

config = {
"model": {
"name": model,
},
"envs": [
{
"name": "WAADesktop",
"n_envs": group_size,
"n_envs": 1,
"data_source": "waa",
"seed": [1, 100, 1],
"max_turns": max_turns,
Expand All @@ -308,27 +305,6 @@ def _generate_training_config(
},
}
],
"algorithm": {
"name": algorithm,
"kl_coef": 0.0,
"epsilon": 0.2,
"gamma": 1.0 if algorithm != "gigpo" else 0.95,
},
"trainer": {
"total_epochs": epochs,
"n_gpus_per_node": n_gpus,
"micro_batch_size": 4,
"gradient_accumulation_steps": 2,
"test_freq": 5,
"experiment_name": f"{algorithm}_waa_desktop",
"project_name": "openadapt-waa-rl",
"logger": ["console", "wandb"],
},
"rollout": {
"temperature": 0.7,
"top_p": 0.95,
"mode": "async",
},
}

# Upload config as YAML
Expand Down Expand Up @@ -396,41 +372,55 @@ def launch_training(
)

# Step 4: Launch training
# VAGEN uses verl's trainer entry point with additional env/agent config.
# The exact command may vary by VAGEN version. The config YAML provides
# the env spec; Hydra overrides configure the verl training loop.
# VAGEN uses its own entry point (vagen.main_ppo) with Hydra config.
# The env spec YAML provides the environment definition;
# Hydra overrides configure the verl training loop.
train_cmd = f"""
cd ~/verl-agent && \\
conda run -n verl-agent python3 -m verl.trainer.main_ppo \\
conda run -n verl-agent python3 -m vagen.main_ppo \\
--config-path=$HOME/verl-agent/vagen/configs \\
--config-name=vagen_multiturn \\
data.train_files={config_path} \\
data.val_files={config_path} \\
data.train_batch_size=1 \\
data.max_prompt_length=2048 \\
data.max_response_length=512 \\
data.return_raw_chat=True \\
data.return_multi_modal_inputs=True \\
algorithm.adv_estimator={algorithm} \\
algorithm.kl_ctrl.kl_coef=0.0 \\
algorithm.gamma={'0.95' if algorithm == 'gigpo' else '1.0'} \\
actor_rollout_ref.model.path={model} \\
actor_rollout_ref.model.enable_gradient_checkpointing=True \\
actor_rollout_ref.actor.optim.lr=1e-6 \\
actor_rollout_ref.actor.ppo_mini_batch_size=1 \\
actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=1 \\
actor_rollout_ref.actor.fsdp_config.param_offload=True \\
actor_rollout_ref.actor.fsdp_config.optimizer_offload=True \\
actor_rollout_ref.rollout.name=vllm \\
actor_rollout_ref.rollout.tensor_model_parallel_size={n_gpus} \\
actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \\
actor_rollout_ref.rollout.enable_chunked_prefill=False \\
actor_rollout_ref.actor.ppo_mini_batch_size=64 \\
actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=8 \\
data.train_files=$HOME/data/verl-agent/visual/train.parquet \\
data.val_files=$HOME/data/verl-agent/visual/test.parquet \\
data.train_batch_size={group_size} \\
data.val_batch_size=128 \\
data.max_prompt_length=2048 \\
data.max_response_length=512 \\
data.return_raw_chat=True \\
data.filter_overlong_prompts=True \\
actor_rollout_ref.rollout.mode=async \\
actor_rollout_ref.rollout.n={group_size} \\
actor_rollout_ref.rollout.tensor_model_parallel_size=1 \\
actor_rollout_ref.rollout.gpu_memory_utilization=0.5 \\
actor_rollout_ref.rollout.enforce_eager=True \\
actor_rollout_ref.rollout.enable_chunked_prefill=True \\
actor_rollout_ref.rollout.multi_turn.enable=True \\
actor_rollout_ref.rollout.agent.agent_loop_config_path=$HOME/verl-agent/vagen/configs/agent.yaml \\
actor_rollout_ref.ref.fsdp_config.param_offload=True \\
trainer.n_gpus_per_node={n_gpus} \\
trainer.nnodes=1 \\
trainer.total_epochs={epochs} \\
trainer.total_training_steps={epochs} \\
trainer.test_freq=5 \\
trainer.experiment_name={algorithm}_waa_desktop \\
trainer.save_freq=25 \\
trainer.val_before_train=True \\
trainer.logger=['console','wandb'] \\
trainer.project_name=openadapt-waa-rl \\
+env_config={config_path}
trainer.experiment_name={algorithm}_waa_desktop
"""
logger.info("Launching training with %s on %d GPU(s)...", algorithm, n_gpus)
logger.info("Model: %s", model)
logger.info("WAA server: %s", waa_server)
logger.info("Evaluate server: %s", evaluate_url)
logger.info("Task: %s", task_id)

result = _ssh_run(ip, train_cmd, username=username, stream=True)
Expand Down Expand Up @@ -479,8 +469,8 @@ def main():
help="Model to train (default: Qwen/Qwen2.5-VL-3B-Instruct)",
)
parser.add_argument(
"--n-gpus", type=int, default=2,
help="Number of GPUs per node (default: 2)",
"--n-gpus", type=int, default=1,
help="Number of GPUs per node (default: 1, use 4 for g5.12xlarge)",
)
parser.add_argument(
"--epochs", type=int, default=100,
Expand Down