2 changes: 2 additions & 0 deletions .pre-commit-config.yaml
@@ -96,6 +96,8 @@ repos:
examples/speculative_decoding/main.py|
examples/speculative_decoding/medusa_utils.py|
examples/speculative_decoding/server_generate.py|
experimental/dms/models/qwen3/configuration_qwen3_dms.py|
experimental/dms/models/qwen3/modeling_qwen3_dms.py|
)$

# Default hook for Apache 2.0 in c/c++/cuda files
130 changes: 130 additions & 0 deletions experimental/dms/ARCHITECTURE.md
@@ -0,0 +1,130 @@
# DMS Architecture and Advanced Options

This document describes DMS internals, configuration options, and how to extend the codebase.

## Code Details

### Eviction Decisions

DMS supports two ways to compute the eviction decision:

- **Extracted from a single neuron of a key or query vector**: see Section 3.1 of [Dynamic Memory Compression: Retrofitting LLMs for Accelerated Inference](https://arxiv.org/pdf/2403.09636). Enable with `dms_separate_alpha=False`.
- **Produced by a learned linear projection (adapter) from the hidden state**: see Section 3.2 of [Inference-Time Hyper-Scaling with KV Cache Compression](https://arxiv.org/pdf/2506.05345). Enable with `dms_separate_alpha=True`.

You can also choose the granularity of eviction decisions:

- `dms_alpha_per: "head"`: decisions are made independently per attention head (KV cache lengths may differ across heads).
- `dms_alpha_per: "layer"`: decisions are shared across heads within a layer (all heads in the layer keep the same number of tokens).

During training, decision logits are augmented with Gumbel noise to enable differentiable gating (`dms.core.get_gating_with_noise`). During inference, a hard threshold is used.
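A minimal sketch of this Gumbel-sigmoid gating, with illustrative names and form (not the actual `dms.core.get_gating_with_noise` signature):

```python
import math
import random

def gumbel_noise(rng):
    # Standard Gumbel sample via inverse transform; epsilons guard the logs.
    u = rng.random()
    return -math.log(-math.log(u + 1e-10) + 1e-10)

def gate(logit, tau=0.1, training=True, rng=None):
    """Soft (training) or hard (inference) eviction gate."""
    if training:
        # Noisy logits through a tempered sigmoid keep the gate differentiable.
        noisy = logit + gumbel_noise(rng or random.Random())
        return 1.0 / (1.0 + math.exp(-noisy / tau))
    # Inference: hard threshold, no noise.
    return 1.0 if logit > 0 else 0.0
```

Lower `tau` sharpens the soft gate toward the hard 0/1 decision used at inference.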

### Attention

The DMS attention implementation (given decision logits) can be found in `dms/attention.py` (see `dms_attn_train_mode`).

### Loss Function

Training uses knowledge distillation with forward KL divergence between student and teacher logits, computed in `dms/training/engine.py` (`distillation_loss`). This is combined with a DMS compression loss that encourages the model to match the target eviction fraction.
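For intuition, a self-contained sketch of forward KL over a single token's logits (the actual `distillation_loss` operates on batched tensors and may differ in details):

```python
import math

def softmax(logits):
    m = max(logits)  # subtract max for numerical stability
    exps = [math.exp(x - m) for x in logits]
    z = sum(exps)
    return [e / z for e in exps]

def forward_kl(teacher_logits, student_logits):
    """KL(teacher || student): the student is pushed to cover the teacher."""
    p = softmax(teacher_logits)
    q = softmax(student_logits)
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)
```

Forward KL is zero only when the student distribution matches the teacher exactly.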

### DMS Schedule

The compression ratio increases linearly from `initial_cr` (typically 1.0) to `final_cr` (e.g., 16.0) over `final_step` training steps. See `dms_schedule()` in `dms/training/engine.py`.
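The ramp can be sketched as follows (mirroring the description above; the real `dms_schedule()` may differ in details):

```python
def dms_schedule(step, initial_cr=1.0, final_cr=16.0, final_step=510):
    """Linear ramp of the compression ratio, clamped at final_cr."""
    frac = min(step / final_step, 1.0)
    return initial_cr + (final_cr - initial_cr) * frac
```

After `final_step`, the target compression ratio stays at `final_cr` for the remainder of training.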

## Advanced Options

### Chunked Prefill

Chunked prefill reduces peak memory usage during the prefill phase by processing the input sequence in fixed-size chunks. Set the chunk size (in tokens) via:

```python
Qwen3ForCausalLMDMS.from_pretrained(..., dms_chunked_prefill=4096)
```
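Conceptually, the prefill loop looks like this toy sketch, where `forward_fn` stands in for a model forward pass that appends the chunk's KV entries to the cache (names are illustrative):

```python
def chunked_prefill(token_ids, chunk_size, forward_fn):
    """Run prefill over fixed-size chunks, reusing the growing KV cache."""
    logits = None
    for start in range(0, len(token_ids), chunk_size):
        chunk = token_ids[start:start + chunk_size]
        # Each call attends over the cache built so far plus this chunk,
        # so peak activation memory scales with chunk_size, not prompt length.
        logits = forward_fn(chunk)
    return logits  # logits for the final prompt position
```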

### Cache Preallocation

The paged KV cache uses a dynamically resizable per-attention-layer block table (similar to `std::vector` in C++), growing as needed during generation. If you know your maximum context length ahead of time, you can preallocate to avoid runtime reallocations:

```python
Qwen3ForCausalLMDMS.from_pretrained(..., dms_preallocate_for_tokens=2048)
```
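The growth behavior can be illustrated with a toy block table; the class and its members are hypothetical stand-ins, not the real cache API:

```python
class BlockTable:
    """Toy growable per-layer block table (illustrative, not the real API)."""

    def __init__(self, block_size=256, preallocate_for_tokens=0):
        self.block_size = block_size
        self.blocks = []
        # Preallocation reserves all blocks up front, avoiding
        # reallocations during generation.
        self._reserve(preallocate_for_tokens)

    def _reserve(self, num_tokens):
        needed = -(-num_tokens // self.block_size)  # ceiling division
        while len(self.blocks) < needed:
            self.blocks.append(object())  # stand-in for a KV block

    def append_token(self, position):
        # Grow on demand, like std::vector::push_back.
        self._reserve(position + 1)
```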

## Retrofitting a New Model Family

To add DMS support for a new model family, create a new directory under `models/`:

```bash
models/new_model/
├── configuration_new_model_dms.py # Config extending the base model config
├── extract.py # Checkpoint extraction
├── modeling_new_model_dms.py # Model with DMS attention
└── train.py # Training entry point
```

The model-specific code should:

1. Extend the model's config class with DMS parameters (see `models/qwen3/configuration_qwen3_dms.py`).
2. Override the attention forward pass and call:
- `dms.core.prepare_attention_input`
- `dms.attention.dms_attention`
3. Add `dms_proj_alpha` and `dms_proj_alpha_norm` layers to the attention layer.
4. Add a YAML config under `configs/`.

Core DMS operations (`prepare_attention_input`, `dms_attention`, `post_process_attention_output`) are model-agnostic; model-specific code provides its Q/K/V projections and any required norms as inputs.

## Adding a New Dataset

To add a new training dataset, edit `dms/training/data.py`:

1. Define `filter_fn` and `extract_fn` for your dataset.
2. Create a `DatasetInfo` instance.

Example:

```python
def my_dataset_filter_fn(ds_elem):
    return ds_elem["quality_score"] > 0.8


def my_dataset_extract_fn(ds_elem):
    return {
        "conversation": [
            {"role": "user", "content": ds_elem["prompt"]},
            {"role": "assistant", "content": ds_elem["response"]},
        ]
    }


MyNewDataset = DatasetInfo(
    args=("org/my-dataset",),
    kwargs={"split": "train"},
    filter_fn=my_dataset_filter_fn,
    extract_fn=my_dataset_extract_fn,
)
```

Then reference it in your YAML config:

```yaml
data:
  blend: "MyNewDataset:0.5,OpenR1Math220k:0.5"
```
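The blend string pairs dataset names with sampling weights. A sketch of how such a string decomposes (illustrative only, not the actual parser in `dms/training/data.py`):

```python
def parse_blend(blend):
    """Split "Name:weight,Name:weight" into (name, weight) pairs."""
    pairs = []
    for item in blend.split(","):
        name, weight = item.split(":")
        pairs.append((name, float(weight)))
    return pairs
```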

## Checkpoint Resume

To resume training from the latest checkpoint, set the following in your YAML config:

```yaml
hf_trainer:
  resume_from_checkpoint: "auto"
```

This auto-detects the latest `checkpoint-N` directory under the output directory. You can also specify an explicit path:

```yaml
hf_trainer:
  resume_from_checkpoint: outputs/qwen3_8b/checkpoint-300
```

Resume works because:

- The Hugging Face Trainer restores optimizer state, LR scheduler state, the training step counter, and RNG states.
- The DMS schedule is deterministic given the current training step.
- Gumbel noise is seeded from `step + process_index + grad_acc_step`.
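The seeding rule means a resumed run draws the same noise as an uninterrupted one. A minimal sketch of the idea (function name is illustrative):

```python
import random

def noise_rng(step, process_index, grad_acc_step):
    """RNG whose seed is fixed by the training position, so a resumed
    run reproduces the same Gumbel noise at the same step."""
    return random.Random(step + process_index + grad_acc_step)
```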
134 changes: 134 additions & 0 deletions experimental/dms/README.md
@@ -0,0 +1,134 @@
# Dynamic Memory Sparsification (DMS)

A minimal, optimized implementation of the DMS algorithm for KV-cache compression, as described in:

> **Inference-Time Hyper-Scaling with KV Cache Compression**
> Adrian Łańcucki, Konrad Staniszewski, Piotr Nawrot, Edoardo M. Ponti
> Paper: [https://arxiv.org/abs/2506.05345](https://arxiv.org/abs/2506.05345)
> NeurIPS: [https://neurips.cc/virtual/2025/loc/san-diego/poster/119605](https://neurips.cc/virtual/2025/loc/san-diego/poster/119605)

Inference-time scaling trades efficiency for improved reasoning by generating longer sequences. In Transformer LLMs, generation cost is often bottlenecked by the size of the key-value (KV) cache. DMS addresses this by learning a KV cache eviction policy that compresses the cache while preserving accuracy.

## How it works

DMS learns a per-head eviction policy that determines which KV cache entries to keep during generation. Rather than immediately discarding tokens, DMS delays eviction decisions, implicitly merging representations and preserving critical information. During training, the compression ratio is gradually increased from 1× to a target value (e.g., 8×), using knowledge distillation to match the outputs of an uncompressed teacher model.

## What makes DMS practical

- Achieves **8× compression** with minimal accuracy loss
- Adapter training: the default recipe trains eviction adapters only and freezes base weights for efficiency
- Requires **~250 training steps** (about **4 hours on 8× H100**) to adapt Qwen3-8B
- Drop-in replacement for Hugging Face models via a custom cache that supports variable sequence lengths across attention heads

| Model family | Size | Training time (8× H100) |
|------------|------|--------------------------|
| Qwen3 | 8B | ~4 hours |

---

## Quick start: Retrofitting Qwen3-8B with DMS

### Installation

This repository is designed to run inside an NVIDIA PyTorch container:

```bash
docker pull nvcr.io/nvidia/pytorch:25.11-py3
```

Clone and install:

```bash
git clone https://github.com/NVIDIA/Model-Optimizer
cd Model-Optimizer/experimental/dms
pip install -e .
```

This single install provides everything needed for training and evaluation (including lm-eval-harness).

### Train DMS adapters

**Note:** The number of GPUs determines the effective batch size. The configuration below was tested on a DGX H100 with 8× H100 80GB GPUs. For debugging with a smaller compute budget (e.g., a single RTX 5090), see [`scripts/train_small_debug.sh`](scripts/train_small_debug.sh).

```bash
bash scripts/train.sh configs/qwen3_8b.yaml
```

This freezes the original Qwen3-8B weights and trains only the DMS eviction-policy parameters using knowledge distillation. Training completes in ~4 hours on a single DGX H100 node.

The trained student model is saved to `outputs/qwen3_8b/student_model/` at the end of training.

To resume training from the latest checkpoint, set `resume_from_checkpoint: "auto"` in the YAML config.

### Extract from an intermediate checkpoint (optional)

To extract a model from an intermediate checkpoint, run:

```bash
python -m models.qwen3.extract \
    --config outputs/qwen3_8b/config.yaml \
    --checkpoint outputs/qwen3_8b/checkpoint-238
```

### Evaluate

Evaluate on the RULER long-context benchmark:

```bash
bash scripts/evaluate.sh outputs/qwen3_8b/student_model
```

**Prerequisite:** The saved model relies on the `dms` package for its attention and cache implementations. Ensure `dms` is installed (`pip install -e .`) in any environment where you load the model for inference or evaluation.

---

## Repository structure

```bash
.
├── configs                        # YAML experiment configs
│   └── qwen3_8b.yaml
├── dms                            # Core DMS library (pip install -e .)
│   ├── attention_prefill.py       # Exact prefill with eviction-based masking
│   ├── attention.py               # DMS attention: train + inference modes
│   ├── cache_paged.py             # Paged cache with block-based memory management
│   ├── cache.py                   # KV cache: HF wrapper + combined + contiguous
│   ├── core.py                    # Shared ops: prepare_attention_input, gating, chunked prefill
│   └── training
│       ├── data.py                # Data pipeline: loading, blending, tokenization
│       └── engine.py              # Distillation, model config, noise, trainer state
├── ARCHITECTURE.md
├── example_inference.ipynb
├── models                         # Model-specific adaptations
│   └── qwen3
│       ├── configuration_qwen3_dms.py  # Qwen3ConfigDMS
│       ├── extract.py                  # Checkpoint extraction
│       ├── modeling_qwen3_dms.py       # Qwen3ForCausalLMDMS
│       └── train.py                    # Training entry point
└── scripts                        # Launch scripts
    ├── evaluate.sh
    └── train.sh
```

For code details, advanced options, and guides on extending DMS, see [ARCHITECTURE.md](ARCHITECTURE.md).

## Limitations

This repository currently supports training eviction adapters only and keeps base model weights frozen. This training approach can achieve comparable accuracy while being roughly two orders of magnitude cheaper than full fine-tuning. In contrast, the original recipe used in the paper updates all model weights during training; we plan to support it in the near future.

For inference, this repository currently supports a single prefill-then-generate workflow. Multi-turn conversations with interleaved `prefill, generate, prefill, ...` steps are not yet optimized: the cache must be reset between independent sequences, and a slow fallback is used that simulates generation via repeated prefill steps. See [example_inference.ipynb](./example_inference.ipynb) for details.

## Citation

If you found DMS useful, please cite:

```bibtex
@inproceedings{lancucki2025inferencetime,
  title={Inference-Time Hyper-Scaling with {KV} Cache Compression},
  author={Adrian {\L}a{\'n}cucki and Konrad Staniszewski and Piotr Nawrot and Edoardo Ponti},
  booktitle={The Thirty-ninth Annual Conference on Neural Information Processing Systems},
  year={2025},
  url={https://openreview.net/forum?id=8ZiElzQxf1}
}
```
64 changes: 64 additions & 0 deletions experimental/dms/configs/qwen3_1.7b.yaml
@@ -0,0 +1,64 @@
# DMS debug configuration for Qwen3-1.7B
# Designed for debugging and code optimization on a limited compute budget.

model:
  name: Qwen/Qwen3-1.7B
  dtype: float32
  forward_fn_kwargs:
    train_attn_kwargs:
      kernel_options:
        BLOCK_M1: 16
        BLOCK_M2: 16
        BLOCK_N1: 16
        BLOCK_N2: 16

dms:
  alpha_scale: 100.0
  initial_alpha_offset: 5.0
  window_size: 512
  disable_eviction: false
  separate_alpha: true
  alpha_per: head
  tau: 0.1
  initial_cr: 1.0
  final_cr: 16.0
  final_step: 510

data:
  blend: "OpenR1Math220k:1.0"
  train_samples: 4000
  max_length: 8192
  concat_always_start_new: true
  process_vocab_using_chunk: 4096
  tokenizer_kwargs:
    enable_thinking: true

hf_trainer:
  output_dir: outputs/qwen3_1.7b_small
  run_name: dms_qwen3_1.7b_small
  max_steps: 544
  per_device_train_batch_size: 1
  gradient_accumulation_steps: 16
  learning_rate: 3.0e-5
  weight_decay: 0.0
  warmup_steps: 0
  lr_scheduler_type: constant
  save_strategy: steps
  save_steps: 34
  save_total_limit: 5
  logging_strategy: steps
  logging_steps: 1
  gradient_checkpointing: false
  tf32: false
  bf16: true
  save_safetensors: false
  adam_beta1: 0.9
  adam_beta2: 0.95
  max_grad_norm: 1.0
  seed: 42
  fsdp: "full_shard offload"
  fsdp_config:
    use_orig_params: true
    sync_module_states: true
    activation_checkpointing: true
  resume_from_checkpoint: # null = fresh start, "auto" = latest, or explicit path