#!/usr/bin/env bash
set -euo pipefail
set -x

# Joint training recipe (DAPO + Jackpot) on GSM8K.
# - Two namespaces (`large`, `small`) are trained jointly.
# - Rollout/logprob provider is `small`.
# - Pairwise KL in `trainer.topologies` couples updates between both models.

# Dataset preparation (from docs/examples/gsm8k_example.rst):
#   cd examples/data_preprocess
#   python3 gsm8k.py --local_save_dir ~/data/gsm8k
DATA_ROOT=${DATA_ROOT:-$HOME/data/gsm8k}
TRAIN_FILE=${TRAIN_FILE:-$DATA_ROOT/train.parquet}
VAL_FILE=${VAL_FILE:-$DATA_ROOT/test.parquet}

if [[ ! -f "$TRAIN_FILE" || ! -f "$VAL_FILE" ]]; then
  echo "Missing GSM8K parquet files under $DATA_ROOT."
  echo "Run: cd examples/data_preprocess && python3 gsm8k.py --local_save_dir $DATA_ROOT"
  exit 1
fi

LARGE_MODEL=${LARGE_MODEL:-Qwen/Qwen3-0.6B-Base}
SMALL_MODEL=${SMALL_MODEL:-Qwen/Qwen2.5-0.5B}
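# Note: the pairwise KL terms below compare token-level distributions, which
# assumes both namespaces share a tokenizer/vocabulary. The Qwen defaults here
# are expected to be token-compatible; if you swap in models from different
# families, verify tokenizer compatibility first.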

TRAIN_BATCH_SIZE=${TRAIN_BATCH_SIZE:-64}
PPO_MINI_BATCH_SIZE=${PPO_MINI_BATCH_SIZE:-64}
ROLLOUT_N=${ROLLOUT_N:-8}
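# With the defaults above, each generation step samples TRAIN_BATCH_SIZE *
# ROLLOUT_N = 64 * 8 = 512 responses (before any DAPO group filtering below).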

# DAPO clipping / filtering knobs.
CLIP_RATIO_LOW=${CLIP_RATIO_LOW:-0.2}
CLIP_RATIO_HIGH=${CLIP_RATIO_HIGH:-0.28}
ENABLE_FILTER_GROUPS=${ENABLE_FILTER_GROUPS:-True}
FILTER_GROUPS_METRIC=${FILTER_GROUPS_METRIC:-acc}
MAX_NUM_GEN_BATCHES=${MAX_NUM_GEN_BATCHES:-10}
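# How these interact (standard DAPO semantics): the PPO ratio is clipped to
# [1 - CLIP_RATIO_LOW, 1 + CLIP_RATIO_HIGH]; the asymmetric upper bound
# ("clip-higher") leaves more room to raise the probability of rare tokens.
# Group filtering drops prompt groups whose ROLLOUT_N samples all score the
# same on FILTER_GROUPS_METRIC (no advantage signal) and regenerates, up to
# MAX_NUM_GEN_BATCHES generation rounds per training batch.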

# DAPO overlong buffer knobs.
ENABLE_OVERLONG_BUFFER=${ENABLE_OVERLONG_BUFFER:-True}
OVERLONG_BUFFER_LEN=${OVERLONG_BUFFER_LEN:-2048}
OVERLONG_PENALTY_FACTOR=${OVERLONG_PENALTY_FACTOR:-1.0}
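# The overlong buffer applies DAPO's soft length penalty: responses that run
# into the last OVERLONG_BUFFER_LEN tokens before max_response_length are
# penalized linearly, scaled by OVERLONG_PENALTY_FACTOR, as they approach the
# hard cap, rather than being penalized only at truncation.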

# Jackpot knobs (OBRS correction):
# - actor.use_jackpot=True enables Jackpot token reweighting.
# - actor.jackpot_use_latest_logits=True recomputes overlap using the current policy's logits.
# - actor.jackpot_log_probs_to_keep controls the top-k width used for overlap mass estimation.
# - actor.jackpot_lambda controls acceptance strictness (higher => stricter correction).
# - actor.jackpot_clip_ratio caps Jackpot importance weights for stability.
# - actor.jackpot_use_topk_renorm=True renormalizes overlap mass inside the kept top-k.
# - rollout.calculate_log_probs=True and rollout.log_probs_to_keep must stay enabled for Jackpot.
JACKPOT_LOGPROBS_TO_KEEP=${JACKPOT_LOGPROBS_TO_KEEP:-20}
JACKPOT_LAMBDA=${JACKPOT_LAMBDA:-1.0}
JACKPOT_CLIP_RATIO=${JACKPOT_CLIP_RATIO:-16.0}
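
# Optional sanity check: Jackpot reads the rollout's kept top-k log-probs, so
# fail fast here if the top-k width is not a positive integer. The same value
# is wired to both actor.jackpot_log_probs_to_keep and rollout.log_probs_to_keep
# below, so they always match.
if ! [[ "$JACKPOT_LOGPROBS_TO_KEEP" =~ ^[0-9]+$ ]] || (( JACKPOT_LOGPROBS_TO_KEEP < 1 )); then
  echo "JACKPOT_LOGPROBS_TO_KEEP must be a positive integer, got: $JACKPOT_LOGPROBS_TO_KEEP"
  exit 1
fi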

FWD_KL_SMALL=${FWD_KL_SMALL:-0.1}
REV_KL_LARGE=${REV_KL_LARGE:-0.00}
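# KL pair semantics (mirrors trainer.topologies below): FWD_KL_SMALL weights
# the forward pair (p=large, q=small) trained on `small`, pulling the small
# model toward the large one; REV_KL_LARGE weights the reverse pair trained on
# `large` (with importance sampling) and is disabled by default (coef 0.00).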

python3 -m recipe.dapo.main_dapo \
  data.train_files="$TRAIN_FILE" \
  data.val_files="$VAL_FILE" \
  data.train_batch_size="${TRAIN_BATCH_SIZE}" \
  data.max_prompt_length=1024 \
  data.max_response_length=4096 \
  data.filter_overlong_prompts=True \
  data.truncation=error \
  +ray_kwargs.ray_init.object_store_memory=144000000000 \
  trainer.namespace=large \
  trainer.train_namespaces=[large,small] \
  trainer.rollout_from=small \
  trainer.critic_warmup=0 \
  trainer.logger=[console,wandb] \
  trainer.project_name=jackpot_gsm8k_release_dapo \
  trainer.experiment_name=qwen_dual_kl_dapo_gsm8k \
  trainer.n_gpus_per_node=1 \
  trainer.nnodes=1 \
  trainer.save_freq=16 \
  trainer.test_freq=16 \
  trainer.total_epochs=8 \
  trainer.max_actor_ckpt_to_keep=1 \
  trainer.val_before_train=False \
  trainer.validation_use_train_namespace=True \
  trainer.resource_pool_name=global_pool \
  actor_rollout_ref.model.path="${LARGE_MODEL}" \
  actor_rollout_ref.actor.optim.lr=1e-6 \
  actor_rollout_ref.model.use_remove_padding=True \
  actor_rollout_ref.model.enable_gradient_checkpointing=True \
  actor_rollout_ref.actor.ppo_mini_batch_size="${PPO_MINI_BATCH_SIZE}" \
  actor_rollout_ref.actor.use_dynamic_bsz=True \
  actor_rollout_ref.actor.ppo_max_token_len_per_gpu=9000 \
  actor_rollout_ref.actor.use_kl_loss=False \
  actor_rollout_ref.actor.clip_ratio_low="${CLIP_RATIO_LOW}" \
  actor_rollout_ref.actor.clip_ratio_high="${CLIP_RATIO_HIGH}" \
  actor_rollout_ref.actor.clip_ratio_c=10.0 \
  actor_rollout_ref.actor.use_jackpot=True \
  actor_rollout_ref.actor.jackpot_use_latest_logits=True \
  actor_rollout_ref.actor.jackpot_log_probs_to_keep="${JACKPOT_LOGPROBS_TO_KEEP}" \
  actor_rollout_ref.actor.jackpot_lambda="${JACKPOT_LAMBDA}" \
  actor_rollout_ref.actor.jackpot_clip_ratio="${JACKPOT_CLIP_RATIO}" \
  actor_rollout_ref.actor.jackpot_use_topk_renorm=True \
  actor_rollout_ref.actor.entropy_coeff=0 \
  actor_rollout_ref.actor.fsdp_config.param_offload=True \
  actor_rollout_ref.actor.fsdp_config.strategy=fsdp2 \
  actor_rollout_ref.actor.fsdp_config.optimizer_offload=True \
  actor_rollout_ref.rollout.tensor_model_parallel_size=1 \
  actor_rollout_ref.rollout.name=vllm \
  actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \
  actor_rollout_ref.rollout.n="${ROLLOUT_N}" \
  actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=30000 \
  actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=True \
  actor_rollout_ref.rollout.free_cache_engine=True \
  actor_rollout_ref.rollout.mode=sync \
  actor_rollout_ref.rollout.calculate_log_probs=True \
  actor_rollout_ref.rollout.log_probs_to_keep="${JACKPOT_LOGPROBS_TO_KEEP}" \
  actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=30000 \
  actor_rollout_ref.ref.log_prob_use_dynamic_bsz=True \
  actor_rollout_ref.ref.fsdp_config.param_offload=True \
  actor_rollout_ref.ref.fsdp_config.strategy=fsdp2 \
  algorithm.adv_estimator=grpo \
  algorithm.use_kl_in_reward=False \
  algorithm.filter_groups.enable="${ENABLE_FILTER_GROUPS}" \
  algorithm.filter_groups.max_num_gen_batches="${MAX_NUM_GEN_BATCHES}" \
  algorithm.filter_groups.metric="${FILTER_GROUPS_METRIC}" \
  reward_model.reward_manager=dapo \
  reward_model.overlong_buffer.enable="${ENABLE_OVERLONG_BUFFER}" \
  reward_model.overlong_buffer.len="${OVERLONG_BUFFER_LEN}" \
  reward_model.overlong_buffer.penalty_factor="${OVERLONG_PENALTY_FACTOR}" \
  "+trainer.topologies=[{name:dual_kl,rollout:small,logprob:small,train:[small,large],logprob_map:{large:large},kl_pairs:[{name:fwd_small_vs_large,train:small,p:large,q:small,mode:forward,coef:${FWD_KL_SMALL}},{name:rev_large_vs_small,train:large,p:large,q:small,mode:reverse,coef:${REV_KL_LARGE},use_is:true}]}]" \
  '+trainer.topology_loop=[{name:dual_kl,repeat:1}]' \
  "+trainer.worker_namespaces=[{name:small,train:true,config:{trainer:{rollout_from:small},actor_rollout_ref:{model:{path:'${SMALL_MODEL}'},actor:{optim:{lr:1e-6},ppo_mini_batch_size:${PPO_MINI_BATCH_SIZE},use_dynamic_bsz:true,ppo_max_token_len_per_gpu:12000,use_jackpot:true,jackpot_use_latest_logits:true,jackpot_log_probs_to_keep:${JACKPOT_LOGPROBS_TO_KEEP},jackpot_lambda:${JACKPOT_LAMBDA},jackpot_clip_ratio:${JACKPOT_CLIP_RATIO},jackpot_use_topk_renorm:true,fsdp_config:{param_offload:true,optimizer_offload:true}},ref:{log_prob_max_token_len_per_gpu:30000},rollout:{tensor_model_parallel_size:1,gpu_memory_utilization:0.6,log_prob_max_token_len_per_gpu:30000,log_prob_use_dynamic_bsz:true,free_cache_engine:true,mode:sync,calculate_log_probs:true,log_probs_to_keep:${JACKPOT_LOGPROBS_TO_KEEP}}}}}]" \
  "$@"
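
# Example usage (the script filename is illustrative). Environment variables
# override the defaults above, and trailing arguments are forwarded to the
# trainer as extra overrides via "$@":
#   DATA_ROOT=/mnt/data/gsm8k ROLLOUT_N=4 \
#     bash run_dapo_jackpot_gsm8k.sh trainer.total_epochs=4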