
Commit 84aceb8

kstaniszewsknv and coderabbitai[bot] authored and committed
Add Dynamic Memory Sparsification (DMS) training and inference implementation (#877)
**Type of change:** new feature

**Overview:** Training and inference code for Dynamic Memory Sparsification (DMS), the method from the NeurIPS 2025 paper [Inference-Time Hyper-Scaling with KV Cache Compression](https://neurips.cc/virtual/2025/loc/san-diego/poster/119605). Detailed in `experimental/dms/README.md` and `experimental/dms/ARCHITECTURE.md`.

DMS tests in `experimental/dms/tests` cover:

* prefill
* generation
* gradient propagation
* chunked prefill

- **Make sure you read and follow [Contributor guidelines](https://github.com/NVIDIA/Model-Optimizer/blob/main/CONTRIBUTING.md)** and your commits are signed.
- **Is this change backward compatible?**: Yes
- **Did you write any new necessary tests?**: Yes
- **Did you add or update any necessary documentation?**: Yes
- **Did you update [Changelog](https://github.com/NVIDIA/Model-Optimizer/blob/main/CHANGELOG.rst)?**: No, DMS is currently an experimental feature with a description in `experimental/dms`

A minimal, optimized implementation of the DMS algorithm for KV-cache compression, as described in:

> **Inference-Time Hyper-Scaling with KV Cache Compression**
> Adrian Łańcucki, Konrad Staniszewski, Piotr Nawrot, Edoardo M. Ponti
> Paper: [https://arxiv.org/abs/2506.05345](https://arxiv.org/abs/2506.05345)
> NeurIPS: [https://neurips.cc/virtual/2025/loc/san-diego/poster/119605](https://neurips.cc/virtual/2025/loc/san-diego/poster/119605)

Inference-time scaling trades efficiency for improved reasoning by generating longer sequences. In Transformer LLMs, generation cost is often bottlenecked by the size of the key-value (KV) cache. DMS addresses this by learning a KV cache eviction policy that compresses the cache while preserving accuracy.

DMS learns a per-head eviction policy that determines which KV cache entries to keep during generation. Rather than immediately discarding tokens, DMS delays eviction decisions, implicitly merging representations and preserving critical information. During training, the compression ratio is gradually increased from 1× to a target value (e.g., 8×), using knowledge distillation to match the outputs of an uncompressed teacher model.

**Release notes (by coderabbit.ai):**

* **New Features**
  * Introduces Dynamic Memory Sparsification (DMS), an algorithm for efficient LLM inference and training with adaptive attention gating.
  * Adds DMS-enabled Qwen3 models with memory-efficient KV cache management and paged block-based storage.
  * Includes student-teacher distillation training infrastructure with noise scheduling and compression ratio control.
  * Provides a configuration system and training/evaluation scripts for DMS adaptation.
* **Documentation**
  * Added architecture guide, README, and example inference notebook.
* **Tests**
  * Added comprehensive test suite for chunked prefill, cache management, and prefill/inference validation.

Signed-off-by: Konrad Staniszewski <kstaniszewsk@nvidia.com>
Signed-off-by: kstaniszewsknv <kstaniszewsk@nvidia.com>
Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com>
1 parent 3d61944 commit 84aceb8

31 files changed (+6824 −0 lines)

experimental/dms/ARCHITECTURE.md

Lines changed: 130 additions & 0 deletions
# DMS Architecture and Advanced Options

This document describes DMS internals, configuration options, and how to extend the codebase.

## Code Details

### Eviction Decisions

DMS supports two ways to compute the eviction decision:

- **Extracted from a single neuron of a key or query vector**: see Section 3.1 of [Dynamic Memory Compression: Retrofitting LLMs for Accelerated Inference](https://arxiv.org/pdf/2403.09636). Enable with `dms_separate_alpha=False`.
- **Produced by a learned linear projection (adapter) from the hidden state**: see Section 3.2 of [Inference-Time Hyper-Scaling with KV Cache Compression](https://arxiv.org/pdf/2506.05345). Enable with `dms_separate_alpha=True`.

You can also choose the granularity of eviction decisions:

- `dms_alpha_per: "head"`: decisions are made independently per attention head (KV cache lengths may differ across heads).
- `dms_alpha_per: "layer"`: decisions are shared across heads within a layer (all heads in the layer keep the same number of tokens).

During training, decision logits are augmented with Gumbel noise to enable differentiable gating (`dms.core.get_gating_with_noise`). During inference, a hard threshold is used.
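The idea can be sketched as a Gumbel-sigmoid gate: add logistic noise (the binary analogue of Gumbel-softmax noise) to the decision logits during training, and threshold them at inference. This is an illustrative NumPy approximation, not the actual `dms.core.get_gating_with_noise` implementation; the `tau=0.1` default mirrors the YAML configs.

```python
import numpy as np

def gumbel_sigmoid_gate(logits, tau=0.1, training=True, rng=None):
    """Sketch of a differentiable binary eviction gate.

    Training: a temperature-scaled sigmoid of (logits + logistic noise)
    yields a soft gate in (0, 1). Inference: a hard 0/1 threshold.
    """
    if not training:
        return (logits > 0.0).astype(np.float64)  # hard eviction decision
    rng = rng or np.random.default_rng()
    u = rng.uniform(1e-9, 1.0 - 1e-9, size=np.shape(logits))
    noise = np.log(u) - np.log1p(-u)  # Logistic(0, 1) = Gumbel - Gumbel
    return 1.0 / (1.0 + np.exp(-(logits + noise) / tau))
```

Lower `tau` pushes the soft gate toward binary decisions, closing the gap between the training-time and inference-time behavior.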
### Attention

The DMS attention implementation (given decision logits) can be found in `dms/attention.py` (see `dms_attn_train_mode`).
### Loss Function

Training uses knowledge distillation with forward KL divergence between student and teacher logits, computed in `dms/training/engine.py` (`distillation_loss`). This is combined with a DMS compression loss that encourages the model to match the target eviction fraction.
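Forward KL distillation means KL(teacher ‖ student): the student is penalized wherever it assigns low probability to tokens the teacher considers likely. A minimal NumPy sketch (the repo's `distillation_loss` is the authoritative version):

```python
import numpy as np

def softmax(logits, axis=-1):
    z = logits - logits.max(axis=axis, keepdims=True)  # numerical stability
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

def forward_kl_distillation_loss(student_logits, teacher_logits):
    """Forward KL(teacher || student) over the vocab, averaged over positions."""
    p_t = softmax(teacher_logits)
    log_p_t = np.log(p_t)
    log_p_s = np.log(softmax(student_logits))
    kl = (p_t * (log_p_t - log_p_s)).sum(axis=-1)
    return kl.mean()
```

The loss is zero when student and teacher distributions match and strictly positive otherwise.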
### DMS Schedule

The compression ratio increases linearly from `initial_cr` (typically 1.0) to `final_cr` (e.g., 16.0) over `final_step` training steps. See `dms_schedule()` in `dms/training/engine.py`.
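A linear ramp of this kind can be written in a few lines; this is an illustrative sketch (the real `dms_schedule()` lives in `dms/training/engine.py`), with defaults taken from the YAML configs:

```python
def dms_schedule(step, initial_cr=1.0, final_cr=16.0, final_step=510):
    # Linearly anneal the compression ratio, then hold at final_cr.
    frac = min(step / final_step, 1.0)
    return initial_cr + frac * (final_cr - initial_cr)
```

For example, halfway through (step 255 of 510) the target compression ratio is 8.5×.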
## Advanced Options

### Chunked Prefill

Chunked prefill reduces peak memory usage during the prefill phase by processing the input sequence in fixed-size chunks. Set the chunk size (in tokens) via:

```python
Qwen3ForCausalLMDMS.from_pretrained(..., dms_chunked_prefill=4096)
```
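The control flow behind chunked prefill is a simple loop over fixed-size slices of the prompt. A hypothetical sketch, where `prefill_chunk_fn` stands in for one model forward pass that appends the chunk's KV entries to the cache:

```python
def chunked_prefill(prefill_chunk_fn, token_ids, chunk_size=4096):
    """Feed a long prompt through the model in fixed-size chunks.

    Peak activation memory is bounded by chunk_size instead of the
    full prompt length; only the final chunk's output is needed to
    start generation.
    """
    output = None
    for start in range(0, len(token_ids), chunk_size):
        output = prefill_chunk_fn(token_ids[start:start + chunk_size])
    return output
```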
### Cache Preallocation

The paged KV cache uses a dynamically resizable per-attention-layer block table (similar to `std::vector` in C++), growing as needed during generation. If you know your maximum context length ahead of time, you can preallocate to avoid runtime reallocations:

```python
Qwen3ForCausalLMDMS.from_pretrained(..., dms_preallocate_for_tokens=2048)
```
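To make the trade-off concrete, here is a toy block table (a hypothetical stand-in for the real paged cache in `dms/cache_paged.py`): blocks are allocated lazily as tokens arrive, unless preallocation reserves them up front.

```python
class BlockTable:
    """Toy per-layer block table; each entry stands in for one KV block."""

    def __init__(self, block_size=16, preallocate_for_tokens=0):
        self.block_size = block_size
        self.blocks = []
        self.num_tokens = 0
        if preallocate_for_tokens:
            self._grow_to(preallocate_for_tokens)  # reserve blocks up front

    def _grow_to(self, tokens):
        needed = -(-tokens // self.block_size)  # ceil division
        while len(self.blocks) < needed:
            self.blocks.append(object())  # stand-in for a KV block allocation

    def append_token(self):
        self.num_tokens += 1
        self._grow_to(self.num_tokens)  # no-op if capacity already reserved
```

With preallocation, generation up to the reserved length triggers no further allocations.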
## Retrofitting a New Model Family

To add DMS support for a new model family, create a new directory under `models/`:

```bash
models/new_model/
├── configuration_new_model_dms.py   # Config extending the base model config
├── extract.py                       # Checkpoint extraction
├── modeling_new_model_dms.py        # Model with DMS attention
└── train.py                         # Training entry point
```

The model-specific code should:

1. Extend the model's config class with DMS parameters (see `models/qwen3/configuration_qwen3_dms.py`).
2. Override the attention forward pass and call:
   - `dms.core.prepare_attention_input`
   - `dms.attention.dms_attention`
3. Add `dms_proj_alpha` and `dms_proj_alpha_norm` layers to the attention layer.
4. Add a YAML config under `configs/`.

Core DMS operations (`prepare_attention_input`, `dms_attention`, `post_process_attention_output`) are model-agnostic; model-specific code provides its Q/K/V projections and any required norms as inputs.
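Step 1 above follows a common pattern: subclass the base config and add the DMS knobs. A hypothetical sketch (class and field names are illustrative, not copied from `configuration_qwen3_dms.py`; the DMS parameter names mirror the YAML configs, and `BaseModelConfig` stands in for the upstream Hugging Face config class):

```python
class BaseModelConfig:
    """Stand-in for the upstream model's config class."""

    def __init__(self, hidden_size=4096, num_attention_heads=32, **kwargs):
        self.hidden_size = hidden_size
        self.num_attention_heads = num_attention_heads

class NewModelConfigDMS(BaseModelConfig):
    """Config extension adding DMS eviction-policy parameters."""

    def __init__(self, dms_separate_alpha=True, dms_alpha_per="head",
                 dms_window_size=512, dms_tau=0.1, **kwargs):
        super().__init__(**kwargs)  # base model parameters pass through
        self.dms_separate_alpha = dms_separate_alpha
        self.dms_alpha_per = dms_alpha_per
        self.dms_window_size = dms_window_size
        self.dms_tau = dms_tau
```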
## Adding a New Dataset

To add a new training dataset, edit `dms/training/data.py`:

1. Define `filter_fn` and `extract_fn` for your dataset.
2. Create a `DatasetInfo` instance.

Example:

```python
def my_dataset_filter_fn(ds_elem):
    return ds_elem["quality_score"] > 0.8

def my_dataset_extract_fn(ds_elem):
    return {
        "conversation": [
            {"role": "user", "content": ds_elem["prompt"]},
            {"role": "assistant", "content": ds_elem["response"]},
        ]
    }

MyNewDataset = DatasetInfo(
    args=("org/my-dataset",),
    kwargs={"split": "train"},
    filter_fn=my_dataset_filter_fn,
    extract_fn=my_dataset_extract_fn,
)
```

Then reference it in your YAML config:

```yaml
data:
  blend: "MyNewDataset:0.5,OpenR1Math220k:0.5"
```
## Checkpoint Resume

To resume training from the latest checkpoint, set the following in your YAML config:

```yaml
hf_trainer:
  resume_from_checkpoint: "auto"
```

This auto-detects the latest `checkpoint-N` directory under the output directory. You can also specify an explicit path:

```yaml
hf_trainer:
  resume_from_checkpoint: outputs/qwen3_8b/checkpoint-300
```

Resume works because:

- The Hugging Face Trainer restores optimizer state, LR scheduler state, the training step counter, and RNG states.
- The DMS schedule is deterministic given the current training step.
- Gumbel noise is seeded from `step + process_index + grad_acc_step`.
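The last point can be illustrated with a small sketch (illustrative only; the exact seed derivation in the repo may differ): because the seed is a pure function of the step, the process rank, and the gradient-accumulation step, a resumed run replays exactly the same noise stream as an uninterrupted one.

```python
import numpy as np

def gumbel_noise_rng(step, process_index, grad_acc_step):
    # Deterministic seed: the same (step, rank, accumulation-step) triple
    # always yields the same noise, before and after a checkpoint resume.
    return np.random.default_rng(step + process_index + grad_acc_step)
```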

experimental/dms/README.md

Lines changed: 134 additions & 0 deletions
# Dynamic Memory Sparsification (DMS)

A minimal, optimized implementation of the DMS algorithm for KV-cache compression, as described in:

> **Inference-Time Hyper-Scaling with KV Cache Compression**
> Adrian Łańcucki, Konrad Staniszewski, Piotr Nawrot, Edoardo M. Ponti
> Paper: [https://arxiv.org/abs/2506.05345](https://arxiv.org/abs/2506.05345)
> NeurIPS: [https://neurips.cc/virtual/2025/loc/san-diego/poster/119605](https://neurips.cc/virtual/2025/loc/san-diego/poster/119605)

Inference-time scaling trades efficiency for improved reasoning by generating longer sequences. In Transformer LLMs, generation cost is often bottlenecked by the size of the key-value (KV) cache. DMS addresses this by learning a KV cache eviction policy that compresses the cache while preserving accuracy.
## How it works

DMS learns a per-head eviction policy that determines which KV cache entries to keep during generation. Rather than immediately discarding tokens, DMS delays eviction decisions, implicitly merging representations and preserving critical information. During training, the compression ratio is gradually increased from 1× to a target value (e.g., 8×), using knowledge distillation to match the outputs of an uncompressed teacher model.
## What makes DMS practical

- Achieves **8× compression** with minimal accuracy loss
- Adapter training: the default recipe trains eviction adapters only and freezes base weights for efficiency
- Requires **~250 training steps** (about **4 hours on 8× H100**) to adapt Qwen3-8B
- Drop-in replacement for Hugging Face models via a custom cache that supports variable sequence lengths across attention heads

| Model family | Size | Training time (8× H100) |
|--------------|------|--------------------------|
| Qwen3        | 8B   | ~4 hours                 |

---
## Quick start: Retrofitting Qwen3-8B with DMS

### Installation

This repository is designed to run inside an NVIDIA PyTorch container:

```bash
docker pull nvcr.io/nvidia/pytorch:25.11-py3
```

Clone and install:

```bash
git clone https://github.com/NVIDIA/Model-Optimizer
cd Model-Optimizer/experimental/dms
pip install -e .
```

This single install provides everything needed for training and evaluation (including lm-eval-harness).
### Train DMS adapters

**Note:** The number of GPUs determines the effective batch size. The configuration below was tested on a DGX H100 with 8× H100 80GB GPUs. For debugging with a smaller compute budget (e.g., a single RTX 5090), see [`scripts/train_small_debug.sh`](scripts/train_small_debug.sh).

```bash
bash scripts/train.sh configs/qwen3_8b.yaml
```

This freezes the original Qwen3-8B weights and trains only the DMS eviction-policy parameters using knowledge distillation. Training completes in ~4 hours on a single DGX H100 node.

The trained student model is saved to `outputs/qwen3_8b/student_model/` at the end of training.

To resume training from the latest checkpoint, set `resume_from_checkpoint: "auto"` in the YAML config.
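The note above about GPU count can be made concrete with the standard data-parallel formula (generic arithmetic, not repo code): the global batch size is the per-device batch times the gradient-accumulation steps times the number of GPUs.

```python
def effective_batch_size(per_device_batch, grad_accum_steps, num_gpus):
    # Global batch size under data-parallel training.
    return per_device_batch * grad_accum_steps * num_gpus
```

With the `qwen3_8b.yaml` values (`per_device_train_batch_size: 1`, `gradient_accumulation_steps: 1`) on 8 GPUs the effective batch size is 8; running the same config on fewer GPUs shrinks it proportionally unless you raise gradient accumulation to compensate.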
### Extract from an intermediate checkpoint (optional)

To extract a model from an intermediate checkpoint, run:

```bash
python -m models.qwen3.extract \
    --config outputs/qwen3_8b/config.yaml \
    --checkpoint outputs/qwen3_8b/checkpoint-238
```
### Evaluate

Evaluate on the RULER long-context benchmark:

```bash
bash scripts/evaluate.sh outputs/qwen3_8b/student_model
```

**Prerequisite:** The saved model relies on the `dms` package for its attention and cache implementations. Ensure `dms` is installed (`pip install -e .`) in any environment where you load the model for inference or evaluation.

---
## Repository structure

```bash
.
├── configs                       # YAML experiment configs
│   └── qwen3_8b.yaml
├── dms                           # Core DMS library (pip install -e .)
│   ├── attention_prefill.py      # Exact prefill with eviction-based masking
│   ├── attention.py              # DMS attention: train + inference modes
│   ├── cache_paged.py            # Paged cache with block-based memory management
│   ├── cache.py                  # KV cache: HF wrapper + combined + contiguous
│   ├── core.py                   # Shared ops: prepare_attention_input, gating, chunked prefill
│   └── training
│       ├── data.py               # Data pipeline: loading, blending, tokenization
│       └── engine.py             # Distillation, model config, noise, trainer state
├── ARCHITECTURE.md
├── example_inference.ipynb
├── models                        # Model-specific adaptations
│   └── qwen3
│       ├── configuration_qwen3_dms.py   # Qwen3ConfigDMS
│       ├── extract.py                   # Checkpoint extraction
│       ├── modeling_qwen3_dms.py        # Qwen3ForCausalLMDMS
│       └── train.py                     # Training entry point
└── scripts                       # Launch scripts
    ├── evaluate.sh
    └── train.sh
```

For code details, advanced options, and guides on extending DMS, see [ARCHITECTURE.md](ARCHITECTURE.md).
## Limitations

This repository currently supports training eviction adapters only and keeps base model weights frozen. This training approach can achieve comparable accuracy while being roughly two orders of magnitude cheaper than full fine-tuning. In contrast, the original recipe used in the paper updates all model weights during training; we plan to support it in the near future.

For inference, this repository currently supports a single prefill-then-generate workflow. Multi-turn conversations with interleaved `prefill, generate, prefill, ...` steps are not yet optimized: the cache must be reset between independent sequences, and a slow fallback is used that simulates generation via repeated prefill steps. See [example_inference.ipynb](./example_inference.ipynb) for details.
## Citation

If you found DMS useful, please cite:

```bibtex
@inproceedings{lancucki2025inferencetime,
  title={Inference-Time Hyper-Scaling with {KV} Cache Compression},
  author={Adrian {\L}a{\'n}cucki and Konrad Staniszewski and Piotr Nawrot and Edoardo Ponti},
  booktitle={The Thirty-ninth Annual Conference on Neural Information Processing Systems},
  year={2025},
  url={https://openreview.net/forum?id=8ZiElzQxf1}
}
```
Lines changed: 64 additions & 0 deletions
```yaml
# DMS debug configuration for Qwen3-1.7B
# Designed for debugging and code optimization on a limited compute budget.

model:
  name: Qwen/Qwen3-1.7B
  dtype: float32
  forward_fn_kwargs:
    train_attn_kwargs:
      kernel_options:
        BLOCK_M1: 16
        BLOCK_M2: 16
        BLOCK_N1: 16
        BLOCK_N2: 16

dms:
  alpha_scale: 100.0
  initial_alpha_offset: 5.0
  window_size: 512
  disable_eviction: false
  separate_alpha: true
  alpha_per: head
  tau: 0.1
  initial_cr: 1.0
  final_cr: 16.0
  final_step: 510

data:
  blend: "OpenR1Math220k:1.0"
  train_samples: 4000
  max_length: 8192
  concat_always_start_new: true
  process_vocab_using_chunk: 4096
  tokenizer_kwargs:
    enable_thinking: true

hf_trainer:
  output_dir: outputs/qwen3_1.7b_small
  run_name: dms_qwen3_1.7b_small
  max_steps: 544
  per_device_train_batch_size: 1
  gradient_accumulation_steps: 16
  learning_rate: 3.0e-5
  weight_decay: 0.0
  warmup_steps: 0
  lr_scheduler_type: constant
  save_strategy: steps
  save_steps: 34
  save_total_limit: 5
  logging_strategy: steps
  logging_steps: 1
  gradient_checkpointing: false
  tf32: false
  bf16: true
  save_safetensors: false
  adam_beta1: 0.9
  adam_beta2: 0.95
  max_grad_norm: 1.0
  seed: 42
  fsdp: "full_shard offload"
  fsdp_config:
    use_orig_params: true
    sync_module_states: true
    activation_checkpointing: true
  resume_from_checkpoint:  # null = fresh start, "auto" = latest, or explicit path
```
Lines changed: 62 additions & 0 deletions
```yaml
# DMS training configuration for Qwen3-8B
#
# Usage:
#   accelerate launch -m models.qwen3.train --config configs/qwen3_8b.yaml
#
# To resume from latest checkpoint:
#   Set resume_from_checkpoint to "auto" below, or pass an explicit path.

model:
  name: Qwen/Qwen3-8B
  dtype: float32

dms:
  alpha_scale: 100.0
  initial_alpha_offset: 5.0
  window_size: 512
  disable_eviction: false
  separate_alpha: true
  alpha_per: head
  tau: 0.1
  initial_cr: 1.0
  final_cr: 16.0
  final_step: 510

data:
  blend: "OpenR1Math220k:1.0"
  train_samples: 4000
  max_length: 32768
  concat_always_start_new: true
  process_vocab_using_chunk: 4096
  tokenizer_kwargs:
    enable_thinking: true

hf_trainer:
  output_dir: outputs/qwen3_8b
  run_name: dms_qwen3_8b
  max_steps: 544
  per_device_train_batch_size: 1
  gradient_accumulation_steps: 1
  learning_rate: 3.0e-5
  weight_decay: 0.0
  warmup_steps: 0
  lr_scheduler_type: constant
  save_strategy: steps
  save_steps: 34
  save_total_limit: 5
  logging_strategy: steps
  logging_steps: 1
  gradient_checkpointing: false
  tf32: false
  bf16: true
  save_safetensors: false
  adam_beta1: 0.9
  adam_beta2: 0.95
  max_grad_norm: 1.0
  seed: 42
  fsdp: "full_shard offload"
  fsdp_config:
    use_orig_params: true
    sync_module_states: true
    activation_checkpointing: true
  resume_from_checkpoint:  # null = fresh start, "auto" = latest, or explicit path
```
