10 changes: 0 additions & 10 deletions ci/gpu_ci_run_skyrl_train.sh
@@ -19,13 +19,3 @@ uv run --directory . --isolated --extra dev --extra fsdp pytest -s tests/backend
# echo "Skipping integrations tests. Failed to execute uv add command"
# echo "$add_integrations"
# fi

-# TODO (sumanthrh): Migrate flashrl to vllm 0.16.0 and re-enable integration test
-# Run tests for vllm 0.9.2
-# TODO (sumanthrh): We should have a better way to override without pinning a flash-attn wheel
-# uv run --isolated --extra fsdp --extra dev \
-# --with vllm==0.9.2 \
-# --with transformers==4.53.0 \
-# --with torch==2.7.0 \
-# --with "flash-attn@https://github.com/Dao-AILab/flash-attention/releases/download/v2.8.3/flash_attn-2.8.3+cu12torch2.7cxx11abiTRUE-cp312-cp312-linux_x86_64.whl" \
-# -- pytest -s -vvv tests/backends/skyrl_train/gpu/gpu_ci/test_engine_generation.py::test_token_based_generation
1 change: 1 addition & 0 deletions docs/content/docs/examples/meta.json
@@ -11,6 +11,7 @@
"search",
"geometry3k",
"visgym",
+"quantized_rollouts",
"mini_swe_agent",
"openenv"
]
80 changes: 80 additions & 0 deletions docs/content/docs/examples/quantized_rollouts.mdx
@@ -0,0 +1,80 @@
---
title: "Training with Quantized Rollouts"
---

<Callout type="info">
Quantized rollouts are supported with the `fsdp` and `megatron` backends.
</Callout>

In this example, we walk through how to train a model with quantized (FP8) rollouts using SkyRL. Quantizing the rollout weights speeds up generation for large models while keeping the trainer in full precision.

## Overview

During RL training, rollouts (generation) are typically the throughput bottleneck. Running the inference engine with quantized weights (e.g. FP8) can significantly speed up generation for larger models. SkyRL supports this directly through vLLM — there is no separate patched engine or extra to install.

There are two pieces to make quantized rollouts work well:

1. **Quantized generation** — we ask vLLM to load and serve the policy weights in a quantized format. This is a pass-through option to the vLLM engine, so no SkyRL-specific integration is required.
2. **Off-policy correction** — quantizing the rollout weights widens the gap between the rollout (inference) distribution and the training distribution. We correct for this mismatch with [Truncated Importance Sampling (TIS)](../algorithms/off_policy_correction), which applies an importance-sampling correction to the policy loss.

### How does it work?

We sample generations from the inference engine with quantized weights. We then compute advantages and the policy loss, applying the TIS correction factor to account for the difference between the rollout and training probability distributions. On each weight update, weights are synced to the inference engine layer by layer in half precision (bfloat16); vLLM then quantizes them to the target format before serving.
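
The correction described above can be sketched in plain Python. This is a minimal illustration, not SkyRL's implementation: the helper name `tis_weight`, the toy log-probabilities, and the REINFORCE-style surrogate loss are all hypothetical, and the sketch assumes the sequence-level ratio is the product of per-token ratios between the trainer's and the rollout engine's probabilities, truncated from above at a cap.

```python
import math

def tis_weight(train_logprobs, rollout_logprobs, clip_high=4.0):
    """Sequence-level truncated importance weight (illustrative sketch).

    Hypothetical helper: assumes the weight is
    min(exp(sum(train_lp - rollout_lp)), clip_high).
    """
    log_ratio = sum(t - r for t, r in zip(train_logprobs, rollout_logprobs))
    return min(math.exp(log_ratio), clip_high)

# Toy per-token log-probabilities for one sampled response (made-up values).
rollout_lp = [-1.20, -0.50, -2.00]  # reported by the quantized inference engine
train_lp = [-1.10, -0.55, -1.90]    # recomputed by the full-precision trainer

weight = tis_weight(train_lp, rollout_lp)
advantage = 1.0
# Simplified REINFORCE-style surrogate, scaled by the TIS weight.
loss = -weight * advantage * sum(train_lp)
print(f"weight={weight:.4f} loss={loss:.4f}")
```

When the two distributions agree exactly the weight is 1 and the loss is unchanged; the cap bounds how much a single badly mismatched sample can scale the update.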

## Enabling quantized generation

Quantization is enabled by passing it through to the vLLM engine via `engine_init_kwargs`:

```bash
generator.inference_engine.engine_init_kwargs.quantization=fp8
```

This uses vLLM's [online dynamic FP8 quantization](https://docs.vllm.ai/en/latest/features/quantization/fp8.html), so no calibration data or pre-quantized checkpoint is required.
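
To build intuition for what dynamic quantization does to the weights, the sketch below simulates it in pure Python. It is not vLLM code: the helper names are hypothetical, and it assumes a single per-tensor scale chosen so that the largest weight maps to the FP8 E4M3 maximum (448), followed by rounding to 3 mantissa bits.

```python
import math

FP8_E4M3_MAX = 448.0  # largest finite magnitude in the FP8 E4M3 format

def round_to_e4m3(value: float) -> float:
    """Round to the nearest FP8 E4M3 value (software simulation, ties to even)."""
    if value == 0.0:
        return 0.0
    sign = -1.0 if value < 0 else 1.0
    mag = min(abs(value), FP8_E4M3_MAX)  # saturate instead of overflowing
    # frexp gives mag = m * 2**e with 0.5 <= m < 1, so the binade exponent is e - 1.
    exponent = max(math.frexp(mag)[1] - 1, -6)  # below 2**-6 the grid is subnormal
    step = 2.0 ** (exponent - 3)  # 3 mantissa bits -> 8 steps per binade
    return sign * round(mag / step) * step

def fake_quantize(weights):
    """Assumed per-tensor dynamic scaling: quantize to E4M3, then dequantize."""
    amax = max(abs(w) for w in weights)
    scale = amax / FP8_E4M3_MAX if amax > 0 else 1.0
    return [round_to_e4m3(w / scale) * scale for w in weights]

weights = [0.5, -1.0, 0.03]  # made-up values
dequantized = fake_quantize(weights)
errors = [abs(w - d) for w, d in zip(weights, dequantized)]
print(dequantized, errors)
```

The scale is derived from the tensor itself, which is why no calibration data is needed; the rounding error it introduces is one source of the rollout/trainer mismatch that TIS corrects.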

## Enabling off-policy correction (TIS)

To apply TIS, we need the inference engine to return the rollout logprobs for the generated tokens, and we configure the correction on the policy loss:

```bash
# logprobs=0 returns rollout logprobs for the generated tokens (required for TIS);
# the two trainer overrides apply sequence-level TIS with an importance-ratio clip.
generator.sampling_params.logprobs=0 \
trainer.algorithm.off_policy_correction.tis_ratio_type=sequence \
trainer.algorithm.off_policy_correction.sequence_tis_ratio_clip_high=4.0
```

TIS can be applied at the `token` or `sequence` level. See the [Off-Policy Correction guide](../algorithms/off_policy_correction) for a full discussion of the available corrections and recommended settings.
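
The difference between the two levels can be illustrated with a short Python sketch. The function `tis_weights` and its `level` argument are hypothetical names for this example only, and the sketch assumes the sequence-level ratio is the product of the per-token ratios; see the guide linked above for the actual definitions.

```python
import math

def tis_weights(train_logprobs, rollout_logprobs, level="token", clip_high=4.0):
    """Return one truncated importance weight per token (illustrative sketch)."""
    log_ratios = [t - r for t, r in zip(train_logprobs, rollout_logprobs)]
    if level == "token":
        # Each token is corrected by its own truncated ratio.
        return [min(math.exp(lr), clip_high) for lr in log_ratios]
    if level == "sequence":
        # One truncated ratio for the whole response, shared by all of its tokens.
        seq_ratio = min(math.exp(sum(log_ratios)), clip_high)
        return [seq_ratio] * len(log_ratios)
    raise ValueError(f"unknown level: {level!r}")

# Made-up log-probabilities: only the last token disagrees between engines.
train_lp = [-1.0, -0.2, -3.0]
rollout_lp = [-1.0, -0.2, -4.0]

token_w = tis_weights(train_lp, rollout_lp, level="token")
sequence_w = tis_weights(train_lp, rollout_lp, level="sequence")
print(token_w)
print(sequence_w)
```

In this toy case the token-level weights differ from 1 only at the mismatched token, while the sequence-level weight spreads the same mismatch across every token of the response.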

## Example

We provide a complete example that trains `Qwen2.5-Coder-7B-Instruct` on the SkyRL-SQL dataset with FP8 rollouts and a full-precision trainer at `examples/train/text_to_sql/run_skyrl_sql_fp8.sh`.

The key parameters are:

```bash title="examples/train/text_to_sql/run_skyrl_sql_fp8.sh"
# TIS parameters
TIS_IMP_RATIO_CAP=4.0
TIS_TYPE=sequence
# returns rollout logprobs for the generated tokens; required for TIS
LOGPROBS=0

uv run --isolated --extra fsdp -m skyrl.train.entrypoints.main_base \
trainer.algorithm.off_policy_correction.tis_ratio_type=$TIS_TYPE \
trainer.algorithm.off_policy_correction.sequence_tis_ratio_clip_high=$TIS_IMP_RATIO_CAP \
generator.sampling_params.logprobs=$LOGPROBS \
generator.inference_engine.backend=vllm \
generator.inference_engine.engine_init_kwargs.quantization=fp8 \
...
```

To run it (from the SkyRL root directory):

```bash
hf download NovaSky-AI/SkyRL-SQL-653-data-newfmt --local-dir $HOME/data/sql --repo-type dataset
export WANDB_API_KEY=<your_key_here>
bash examples/train/text_to_sql/run_skyrl_sql_fp8.sh
```

<Callout type="warn">
Quantized rollouts are most beneficial for larger models, where generation dominates step time. For smaller models the overhead of quantizing weights during each weight sync can outweigh the generation speedup. We recommend benchmarking on your own model and hardware.
</Callout>
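
A back-of-the-envelope check for the trade-off described in the callout can be written as a small Python sketch. The function name and every number below are hypothetical placeholders, not benchmark results; substitute timings measured on your own model and hardware.

```python
def net_step_time_saved(gen_time_s, gen_speedup, sync_overhead_s):
    """Seconds saved per training step by quantized rollouts (negative = slower)."""
    if gen_speedup <= 0:
        raise ValueError("gen_speedup must be positive")
    quantized_gen_time = gen_time_s / gen_speedup
    return gen_time_s - quantized_gen_time - sync_overhead_s

# Hypothetical measurements, not benchmark results.
large = net_step_time_saved(gen_time_s=300.0, gen_speedup=1.5, sync_overhead_s=10.0)
small = net_step_time_saved(gen_time_s=20.0, gen_speedup=1.25, sync_overhead_s=10.0)
print(f"large model: {large:+.1f}s per step, small model: {small:+.1f}s per step")
```

A positive result means the generation speedup more than pays for the extra quantization work at each weight sync; a negative result means full-precision rollouts are faster for that configuration.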
2 changes: 1 addition & 1 deletion docs/lib/sort-tree.ts
@@ -19,7 +19,7 @@ const pageOrder: Record<string, string[]> = {
'getting-started': ['installation', 'quickstart', 'overview', 'development'],
'datasets': ['dataset-preparation'],
'tutorials': ['new_env', 'one_step_off_async', 'fully_async', 'tools_guide', 'skyrl_gym_generator'],
-'examples': ['megatron', 'ppo', 'lora', 'llm_as_a_judge', 'remote_server', 'training_backends', 'multi_turn_text2sql', 'search', 'mini_swe_agent', 'openenv'],
+'examples': ['megatron', 'ppo', 'lora', 'llm_as_a_judge', 'remote_server', 'training_backends', 'multi_turn_text2sql', 'search', 'quantized_rollouts', 'mini_swe_agent', 'openenv'],
'platforms': ['overview', 'anyscale', 'runpod', 'skypilot'],
'recipes': ['overview', 'skyrl-sql', 'searchr1'],
'algorithms': ['dapo', 'custom_algorithms'],
4 changes: 2 additions & 2 deletions examples/train/README.md
@@ -6,7 +6,7 @@ Welcome to the SkyRL-Train examples! In this folder you can find the following e
- `algorithms/`: Examples for how to configure and run RL with various algorithms and policy-loss variants (e.g., DAPO, SAPO, GRPO, CISPO, GSPO, or your own custom advantage estimators and custom policy losses).
- `ppo/`: Vanilla PPO training (with a critic, ref, and policy model)
- `on_policy_distillation/`: [On-policy distillation recipe](https://novasky-ai.notion.site/on-policy-distillation) that uses a teacher model to provide dense token-level rewards during training, reproducing results from the [Thinking Machines blog](https://thinkingmachines.ai/blog/on-policy-distillation/).
-- `tis_correction/`: Applying [Flash-RL TIS](https://fengyao.notion.site/off-policy-rl) correction to improve off-policy stability.
+- `tis_correction/`: Applying [Truncated Importance Sampling (TIS)](https://fengyao.notion.site/off-policy-rl) correction to improve off-policy stability.
- `turn_level_rewards/`: GSM8K multi-turn environment illustrating turn-level rewards and custom advantage estimators.

## Async RL
@@ -20,7 +20,7 @@ Welcome to the SkyRL-Train examples! In this folder you can find the following e
- `llm_as_a_judge/`: GSM8K training with an external LLM as a judge to produce rewards instead of strict exact-match grading.
- `multiply/`: Toy arithmetic environment for multiplying numbers, useful for quick sanity checks and debugging.
- `livecodebench/`: LiveCodeBench code-generation task setup and training scripts.
-- `text_to_sql/`: [Text-to-SQL (SkyRL-SQL)](https://docs.skyrl.ai/docs/examples/multi_turn_text2sql) environment and training scripts for mapping natural language questions to SQL queries.
+- `text_to_sql/`: [Text-to-SQL (SkyRL-SQL)](https://docs.skyrl.ai/docs/examples/multi_turn_text2sql) environment and training scripts for mapping natural language questions to SQL queries. Includes `run_skyrl_sql_fp8.sh` for [quantized (FP8) rollouts](https://docs.skyrl.ai/docs/examples/quantized_rollouts).
- `step_wise/`: Step-wise training for chat-template agnostic multi-turn RL training.
- `search/`: Multi-turn search agent training with the SearchR1 dataset, backed by a FAISS-based retriever server.

2 changes: 1 addition & 1 deletion integrations/arctic_rl/examples/run_bird_grpo_32b_32gpu.sh
@@ -12,7 +12,7 @@ set -euxo pipefail
SKYRL_DIR=${SKYRL_DIR:-$(cd "$(dirname "$0")"/../../.. && pwd)}
DATA_DIR=${DATA_DIR:-"$HOME/data/bird"}

-# Driver (same shape as flash_rl/harbor; see run_gsm8k_grpo_4gpu.sh for details).
+# Driver (same shape as harbor; see run_gsm8k_grpo_4gpu.sh for details).
# FA3 (cp39-abi3 wheel from PyTorch's cu128 index) ships alongside FA2 — this
# is the recipe that produced the 2.38x BIRD-SQL speedup on H200.
FLASH_ATTN_WHL="https://github.com/lesj0610/flash-attention/releases/download/v2.8.3-cu12-torch2.10-cp312/flash_attn-2.8.3%2Bcu12torch2.10cxx11abiTRUE-cp312-cp312-linux_x86_64.whl"
Expand Down
Original file line number Diff line number Diff line change
@@ -18,7 +18,7 @@ set -euxo pipefail
SKYRL_DIR=${SKYRL_DIR:-$(cd "$(dirname "$0")"/../../.. && pwd)}
DATA_DIR=${DATA_DIR:-"$HOME/data/bird"}

-# Driver (same shape as flash_rl/harbor; see run_gsm8k_grpo_4gpu.sh for details).
+# Driver (same shape as harbor; see run_gsm8k_grpo_4gpu.sh for details).
# FSDP variant matches the arctic-rl recipe stack so attention is apples-to-apples.
FLASH_ATTN_WHL="https://github.com/lesj0610/flash-attention/releases/download/v2.8.3-cu12-torch2.10-cp312/flash_attn-2.8.3%2Bcu12torch2.10cxx11abiTRUE-cp312-cp312-linux_x86_64.whl"
FLASH_ATTN3_WHL="https://download.pytorch.org/whl/cu128/flash_attn_3-3.0.0-cp39-abi3-manylinux_2_28_x86_64.whl"
2 changes: 1 addition & 1 deletion integrations/arctic_rl/examples/run_bird_grpo_8b_32gpu.sh
@@ -17,7 +17,7 @@ set -euxo pipefail
SKYRL_DIR=${SKYRL_DIR:-$(cd "$(dirname "$0")"/../../.. && pwd)}
DATA_DIR=${DATA_DIR:-"$HOME/data/bird"}

-# Driver (same shape as flash_rl/harbor; see run_gsm8k_grpo_4gpu.sh for details).
+# Driver (same shape as harbor; see run_gsm8k_grpo_4gpu.sh for details).
# FA3 (cp39-abi3 wheel from PyTorch's cu128 index) is the default on Hopper.
FLASH_ATTN_WHL="https://github.com/lesj0610/flash-attention/releases/download/v2.8.3-cu12-torch2.10-cp312/flash_attn-2.8.3%2Bcu12torch2.10cxx11abiTRUE-cp312-cp312-linux_x86_64.whl"
FLASH_ATTN3_WHL="https://download.pytorch.org/whl/cu128/flash_attn_3-3.0.0-cp39-abi3-manylinux_2_28_x86_64.whl"
2 changes: 1 addition & 1 deletion integrations/arctic_rl/examples/run_bird_grpo_smoke.sh
@@ -25,7 +25,7 @@ set -euxo pipefail
SKYRL_DIR=${SKYRL_DIR:-$(cd "$(dirname "$0")"/../../.. && pwd)}
DATA_DIR=${DATA_DIR:-"$HOME/data/bird"}

-# Driver (same shape as flash_rl/harbor; see run_gsm8k_grpo_4gpu.sh for details).
+# Driver (same shape as harbor; see run_gsm8k_grpo_4gpu.sh for details).
# Smoke mirrors the BIRD recipe stack — FA3 wheel included so the smoke and
# the full recipe exercise the same attention backend.
FLASH_ATTN_WHL="https://github.com/lesj0610/flash-attention/releases/download/v2.8.3-cu12-torch2.10-cp312/flash_attn-2.8.3%2Bcu12torch2.10cxx11abiTRUE-cp312-cp312-linux_x86_64.whl"
2 changes: 1 addition & 1 deletion integrations/arctic_rl/examples/run_gsm8k_grpo_4gpu.sh
@@ -35,7 +35,7 @@ if [[ ! -f "${DATA_DIR}/train.parquet" || ! -f "${DATA_DIR}/validation.parquet"
uv run --isolated examples/train/gsm8k/gsm8k_dataset.py --output_dir "${DATA_DIR}"
fi

-# Driver (same shape as examples/train/flash_rl + examples/train_integrations/harbor):
+# Driver (same shape as examples/train_integrations/harbor):
# - `--isolated` gets a fresh resolution that ignores the project lock /
# vllm-cu129 source pin (arctic-inference needs vLLM 0.18, not 0.23).
# - `--with flash-attn@URL` swaps the torch-2.11 wheel in tool.uv.sources for
13 changes: 1 addition & 12 deletions pyproject.toml
@@ -148,15 +148,6 @@ megatron = [
"nvidia-modelopt; sys_platform == 'linux'",
]

-flashrl = [
-"skyrl[skyrl-train]",
-# NOTE: Custom vLLM wheel must be installed separately.
-# See examples/train/flash_rl/README.md for installation instructions.
-"flash-attn==2.8.3; sys_platform == 'linux'",
-"torch==2.7.0; sys_platform == 'linux'",
-"flashinfer-python; sys_platform == 'linux'",
-"torchvision; sys_platform == 'linux'",
-]
miniswe = [
"skyrl[skyrl-train]",
# NOTE (sumanthrh): Needs to be a commit after https://github.com/SWE-agent/mini-swe-agent/commit/4f5d445e99d13b5482478c23508bf2fbf7c0670c
@@ -207,13 +198,11 @@ conflicts = [
{ extra = "jax" },
{ extra = "megatron" },
{ extra = "fsdp" },
-{ extra = "flashrl" },
],
[
{ extra = "megatron" },
{ extra = "gpu" },
{ extra = "tpu" },
-{ extra = "flashrl" },
{ extra = "miniswe" },
]
]
@@ -302,7 +291,7 @@ flash-attn = { url = "https://github.com/erictang000/flash-attention/releases/do
causal-conv1d = { url = "https://github.com/erictang000/causal-conv1d/releases/download/v1.6.1.post4-torch2.11/causal_conv1d-1.6.1-cp312-cp312-linux_x86_64.whl", marker = "sys_platform == 'linux' and python_version == '3.12' and platform_machine == 'x86_64'" }
mamba-ssm = { url = "https://github.com/erictang000/mamba/releases/download/v2.3.1-torch2.11/mamba_ssm-2.3.1-cp312-cp312-linux_x86_64.whl", marker = "sys_platform == 'linux' and python_version == '3.12' and platform_machine == 'x86_64'" }
# CUDA torch on Linux, CPU torch on macOS (must match skyrl-train).
-# Pinned to cu128: flashrl extra needs torch 2.7, only on cu128.
+# Linux uses the cu128 index (torch 2.11 wheels are published there).
torch = [
{ index = "pytorch-cu128", marker = "sys_platform == 'linux'" },
{ index = "pytorch-cpu", marker = "sys_platform == 'darwin'" },