65 changes: 64 additions & 1 deletion docs/source/en/continuous_batching.md
@@ -98,12 +98,32 @@ manager.start()
# submit and retrieve requests...
```

Call [`ContinuousBatchingManager.stop`] to terminate the manager.
### Shutting down the manager

The manager runs a background thread and holds distributed resources. Shutdown happens in two stages so you can choose what to do with in-flight work.

Call [`~ContinuousBatchingManager.stop`] to halt the background thread. By default, the manager stops accepting new submissions and waits for queued and active requests to finish before the thread exits.

```py
manager.stop()
```

Pass `hard_stop=True` to abandon pending work immediately. Queued and active requests are failed with a `RuntimeError` instead of finishing.

```py
manager.stop(hard_stop=True)
```

Once `stop` is called, [`~ContinuousBatchingManager.add_request`] and [`~ContinuousBatchingManager.add_requests`] drop new submissions and log a warning. You can still call `start` again to run another generation session with the same manager.

Call [`~ContinuousBatchingManager.destroy`] to release distributed resources. `destroy` stops the manager first if it's still running, and the manager cannot be restarted afterwards. Use it when you're done with continuous batching for the lifetime of the process.

```py
manager.destroy()
```

[`~ContinuousMixin.continuous_batching_context_manager`] handles both stages for you. On exit it calls `stop`, and then `destroy` unless you pass `persistent_manager=True`, which caches the manager on the model for the next session.
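
If you manage the lifecycle yourself instead of using the context manager, the two stages map naturally onto a `try`/`finally` block. The helper below is a minimal sketch, not part of the library: `run_session` and its `work`, `hard_stop_on_error`, and `release` parameters are hypothetical names. It only assumes the `start`, `stop(hard_stop=...)`, and `destroy` methods described above.

```py
def run_session(manager, work, hard_stop_on_error=True, release=True):
    """Run `work(manager)` between start() and a two-stage shutdown.

    Hypothetical helper: `manager` is any object exposing start(),
    stop(hard_stop=...) and destroy(), as described in this section.
    """
    manager.start()
    failed = False
    try:
        return work(manager)
    except BaseException:
        # Remember the failure so pending requests can be abandoned.
        failed = True
        raise
    finally:
        # Stage 1: halt the background thread. Drain pending work on
        # success, abandon it on failure if hard_stop_on_error is set.
        manager.stop(hard_stop=failed and hard_stop_on_error)
        # Stage 2: release distributed resources. Skip this if you
        # plan to call start() again on the same manager.
        if release:
            manager.destroy()
```

Pass `release=False` when you want to reuse the manager for another session, since a destroyed manager cannot be restarted.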

### Adding requests

[`~ContinuousBatchingManager.add_request`] submits a single request. Provide a `request_id` or let the manager generate one automatically.
@@ -285,6 +305,16 @@ cb_config = ContinuousBatchingConfig(cpu_offload_space=8.0)

By default, `cpu_offload_space_safety_threshold=0.8` limits the requested space to 80% of available system RAM when `psutil` is installed. Set `cpu_offload_space=None` to size the swap pool from the safety threshold.
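
The arithmetic behind the threshold can be sketched as follows. This is an illustration of the rule stated above, not the library's code: `effective_offload_space` is a hypothetical function, and it assumes the requested space and the available RAM are expressed in the same unit.

```py
def effective_offload_space(requested, available_ram, safety_threshold=0.8):
    """Return the swap pool size after applying the safety threshold.

    Hypothetical helper. `requested` mirrors `cpu_offload_space` and
    `safety_threshold` mirrors `cpu_offload_space_safety_threshold`.
    """
    cap = safety_threshold * available_ram
    if requested is None:
        # No explicit request: size the pool from the threshold alone.
        return cap
    # An explicit request is honored only up to the cap.
    return min(requested, cap)
```

With 64 units of available RAM, a request of 8.0 is kept as is, while the same request on a machine with only 5 units available is capped to 4.0.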

### Tensor parallel timeout

Under tensor parallelism, the manager creates a CPU communication group to coordinate request submissions, cancellations, and shutdown across ranks. `cpu_group_timeout` limits how long a collective on this group can block before the process crashes. If one rank stalls, the timeout prevents the others from waiting forever.

Set a longer timeout for workloads that issue infrequent collectives, or pass `None` to disable it.

```py
cb_config = ContinuousBatchingConfig(cpu_group_timeout=600.0)
```

### Prefix caching

When multiple requests share a common prefix, like a system prompt, the manager reuses their KV cache blocks instead of recomputing them. This is enabled by default and requires all model layers to use full attention (it's automatically disabled for sliding window models).
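
To build intuition for why a shared prefix saves work, here is a toy model of block-level reuse. It is not the manager's implementation: the block size, the hash chaining, and the names `block_keys` and `count_reused_blocks` are all assumptions made for illustration.

```py
BLOCK_SIZE = 4  # hypothetical block size, chosen small for readability


def block_keys(token_ids, block_size=BLOCK_SIZE):
    """Return one key per full block, chained so a key depends on its prefix."""
    keys = []
    parent = None
    for start in range(0, len(token_ids) - block_size + 1, block_size):
        block = tuple(token_ids[start:start + block_size])
        # Chaining the parent key means two blocks match only if
        # everything before them matched too.
        parent = hash((parent, block))
        keys.append(parent)
    return keys


def count_reused_blocks(cache, token_ids):
    """Count leading blocks already in `cache`, then register this request."""
    keys = block_keys(token_ids)
    reused = 0
    for key in keys:
        if key not in cache:
            break
        reused += 1
    cache.update(keys)
    return reused
```

In this toy model, two prompts that start with the same 8-token system prompt share their first two blocks, so the second request only needs to compute its third block.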
@@ -314,6 +344,39 @@ model = AutoModelForCausalLM.from_pretrained(
)
```

## Tensor parallelism

For models too large to fit on a single GPU, shard the weights across devices with tensor parallelism. Load the model with `tp_plan="auto"` and continuous batching reads the tensor parallel size from the model to size the paged KV cache per shard. See [Tensor parallelism](./tensor_parallelism) for the list of supported architectures and how sharding works.

```py
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.generation import ContinuousBatchingConfig, GenerationConfig

model = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen3-32B",
    attn_implementation="paged|flash_attention_2",
    tp_plan="auto",
)
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-32B")

inputs = [tokenizer.encode(p) for p in ["What's up?", "Name a cat breed."]]
generation_config = GenerationConfig(max_new_tokens=64, eos_token_id=tokenizer.eos_token_id)

outputs = model.generate_batch(inputs=inputs, generation_config=generation_config)
```

Launch the script with `torchrun`, setting `--nproc-per-node` to the number of GPUs you want to shard across.

```shell
torchrun --nproc-per-node 4 cb_tp.py
```

The tensor parallel size must divide the model's `num_key_value_heads` (check the model config). The paged cache raises an error at startup otherwise, so choose an appropriate `--nproc-per-node`.
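
To pick a value before launching, you can list the GPU counts that satisfy the divisibility rule. The helper below is a small sketch with a hypothetical name, `valid_tp_sizes`. It takes the head count as a plain integer, which you would read from the model config's `num_key_value_heads` field.

```py
def valid_tp_sizes(num_key_value_heads, max_gpus):
    """List tensor parallel sizes up to `max_gpus` that divide the KV head count."""
    return [
        size
        for size in range(1, max_gpus + 1)
        if num_key_value_heads % size == 0
    ]


# Example with 8 KV heads (an illustrative value) on an 8-GPU node.
print(valid_tp_sizes(8, 8))  # [1, 2, 4, 8]
```

Any value in the returned list is a safe choice for `--nproc-per-node`. A value outside it, such as 3 or 6 for 8 KV heads, would trigger the startup error.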

> [!WARNING]
> Don't set `device_map` with `tp_plan`. The two conflict because `device_map` places whole modules on specific GPUs, while `tp_plan` shards those same parameters across all GPUs.

## Sliding window attention

Models with sliding window attention (Mistral, Gemma 2) work with continuous batching. To manually configure a sliding window for fine-tuning or custom experiments, set it in the model config before loading.