
Commit e748f73

SpecDec Bench: February Update

Signed-off-by: Izzy Putterman <iputterman@nvidia.com>
1 parent 5e43b2a

31 files changed: +2420 −200 lines

examples/specdec_bench/README.md

Lines changed: 46 additions & 3 deletions

@@ -3,8 +3,8 @@
## Installation

This benchmark is meant to be a lightweight layer on top of an existing vLLM/SGLang/TRTLLM installation. For example, no install
-is required if one is running in the following dockers: `vllm/vllm-openai:v0.11.0` (vLLM), `lmsysorg/sglang:v0.5.4.post2` (SGLang), or
-`nvcr.io/nvidia/tensorrt-llm/release:1.2.0rc4` (TRT-LLM).
+is required if one is running in the following dockers: `vllm/vllm-openai:v0.15.0` (vLLM), `lmsysorg/sglang:v0.5.7` (SGLang), or
+`nvcr.io/nvidia/tensorrt-llm/release:1.3.0rc2` (TRT-LLM).

Next
@@ -41,9 +41,52 @@ python3 run.py --model_dir openai/gpt-oss-120b --tokenizer openai/gpt-oss-120b -

### Running [SPEED-Bench](https://huggingface.co/datasets/nvidia/SPEED-Bench) on Llama 3.3 70B + Eagle 3

1. Install the requirements file using `pip install -r requirements.txt`
2. Prepare the data using the provided script:

```bash
python3 prepare_data.py --dataset speed --config all
```

The data will be saved to the `data/` directory, with each config type (qualitative, throughput_1k, ...) in its own subdirectory.

#### License

GOVERNING TERMS: This dataset is governed by the NVIDIA Evaluation Dataset License Agreement.

ADDITIONAL INFORMATION: MIT for bigcode/humanevalpack, RUCAIBox/MMATH, RUCAIBox/BAMBOO and EQ-Bench. Apache 2.0 for Writing Bench and Spec-Bench. CC BY 4.0 for FBK-MT/MCIF. MIT and Apache 2.0 for tianyang/repobench_python_v1.1, JetBrains-Research/lca-project-level-code-completion and tianyang/repobench_java_v1.1.

NOTICE: For each dataset a user elects to use, the user is responsible for checking if the dataset license is fit for the intended purpose. The `prepare_data.py` script automatically fetches data from all the source datasets.

Additional details are in the [HuggingFace dataset repository](https://huggingface.co/datasets/nvidia/SPEED-Bench).

#### Qualitative split:

```bash
python3 run.py --model_dir meta-llama/Llama-3.3-70B-Instruct --tokenizer meta-llama/Llama-3.3-70B-Instruct --draft_model_dir yuhuili/EAGLE3-LLaMA3.3-Instruct-70B --dataset speed --dataset_path data/speed/qualitative --tp_size 8 --ep_size 1 --draft_length 3 --output_length 4096 --engine TRTLLM --concurrency 32 --show_progress
```

#### Throughput split:

```bash
python3 run.py --model_dir meta-llama/Llama-3.3-70B-Instruct --tokenizer meta-llama/Llama-3.3-70B-Instruct --draft_model_dir yuhuili/EAGLE3-LLaMA3.3-Instruct-70B --dataset speed --dataset_path data/speed/throughput_1k --tp_size 8 --ep_size 1 --draft_length 3 --output_length 4096 --engine TRTLLM --concurrency 32 --show_progress
```

For longer context (>8192 tokens), please use the following configuration when using TRTLLM:

```yaml
engine_args:
  max_seq_len: 131072  # Model max context length (for Llama 3.3 70B)
  enable_chunked_prefill: true
```

```bash
python3 run.py --model_dir meta-llama/Llama-3.3-70B-Instruct --tokenizer meta-llama/Llama-3.3-70B-Instruct --draft_model_dir yuhuili/EAGLE3-LLaMA3.3-Instruct-70B --dataset speed --dataset_path data/speed/throughput_16k --tp_size 8 --ep_size 1 --draft_length 3 --output_length 4096 --engine TRTLLM --concurrency 32 --show_progress --runtime_params runtime_args_long_context.yaml
```

## Notes

The goal of this benchmark is to provide an easy way to configure, run, and compare speculative implementations across frameworks in an apples-to-apples manner.
This benchmark sends requests in a single-threaded fashion, so running at large concurrency (>256) may result in Python async scheduling delays and skew metrics.
If larger concurrency is needed, it is recommended to fully deploy the model using `vllm serve`, `python -m sglang.launch_server`, or `trtllm-serve` (for vLLM, SGLang, or TRTLLM respectively) and use a more robust benchmarking client like NVIDIA AI Perf.
Lines changed: 319 additions & 0 deletions

# Porting Spec-Bench Inference Runners to specdec_bench

This guide explains how to convert any `inference_*.py` runner from [Spec-Bench](https://github.com/hemingkx/Spec-Bench) to a model class compatible with `specdec_bench`.

## Overview

Spec-Bench inference runners follow a pattern where:

1. A `*_forward()` function handles the speculative decoding logic
2. The `run_eval()` function orchestrates evaluation with tokenized inputs
3. Models are loaded in `__main__` and passed to `run_eval()`

In contrast, `specdec_bench` uses a class-based approach where:

1. Models inherit from the `Model` base class
2. `__init__()` handles model loading
3. `run()` is an async method that processes single requests
4. `stop()` handles cleanup
## The specdec_bench Model Interface

```python
class Model:
    def __init__(self, model_dir, tokenizer, max_draft_length):
        raise NotImplementedError

    async def run(self, prompt_ids, sampling_params, request_id, turn_id):
        """
        prompt_ids: list of token IDs (not a tensor!)

        Returns dict with:
        - output_ids: list of list of token chunks per step [[chunk1, chunk2, ...]]
        - output_logits: optional logits (usually None)
        - token_times: list of timestamps per decoding step
        """
        raise NotImplementedError

    def stop(self):
        pass
```
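To make the contract concrete, here is a hypothetical toy implementation. `EchoModel` is not part of specdec_bench; it fakes generation by echoing the prompt back one token per "step", and a real subclass would inherit from `Model` and load an actual model:

```python
import asyncio
import time

class EchoModel:  # in practice: class EchoModel(Model)
    """Toy model: 'generates' by echoing the prompt, one accepted token per step."""

    def __init__(self, model_dir, tokenizer, max_draft_length):
        self.max_draft_length = max_draft_length  # unused by the toy

    async def run(self, prompt_ids, sampling_params, request_id, turn_id):
        token_times = [time.perf_counter()]
        chunks = []
        for tok in prompt_ids:       # one decoding "step" per prompt token
            chunks.append([tok])     # each step accepts exactly one token
            token_times.append(time.perf_counter())
        # output_ids is a list of per-step chunks, wrapped once for beam_width=1
        return {"output_ids": [chunks], "output_logits": None,
                "token_times": token_times}

    def stop(self):
        pass

result = asyncio.run(EchoModel("unused", None, 3).run([5, 6, 7], {}, "r0", 0))
print(result["output_ids"])   # [[[5], [6], [7]]]
```

Note that `token_times` has one more entry than the number of steps (the initial timestamp), which is what lets per-step latencies be recovered by differencing.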
## Step-by-Step Porting Guide

### Step 1: Identify the Key Components in Spec-Bench

Look at the `inference_*.py` file and identify:

1. **The forward function** (e.g., `medusa_forward`, `ea_forward`)
   - This contains the core speculative decoding loop
   - Signature: `forward_func(inputs, model, tokenizer, max_new_tokens, **kwargs)`
   - Returns: `(output_ids, new_token_count, num_steps, accept_length_list)`

2. **The model class** (e.g., `MedusaModel`, `EaModel`)
   - Found in the `model/<method>/` directory
   - Has a `from_pretrained()` class method

3. **Required utilities** from the method's module:
   - Buffer generation (e.g., `generate_medusa_buffers`)
   - Initialization functions (e.g., `initialize_medusa`, `initialize_past_key_values`)
   - Decoding functions (e.g., `tree_decoding`, `generate_candidates`)
   - State update functions (e.g., `update_inference_inputs`)

4. **Method-specific choices/configs** (e.g., `mc_sim_7b_63` for Medusa)
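As a point of reference, the forward-function contract identified above can be sketched with a toy stand-in. Nothing here is Spec-Bench code; the "model" is a stub that accepts a fixed number of draft tokens per step, where a real method would run tree drafting and verification:

```python
def toy_forward(input_ids, model, tokenizer, max_new_tokens):
    """Mimics the (output_ids, new_token_count, num_steps, accept_length_list) contract."""
    output_ids = list(input_ids)
    accept_length_list = []
    new_token_count = 0
    num_steps = 0
    while new_token_count < max_new_tokens:
        accepted = model(output_ids)          # tokens accepted this step
        output_ids.extend(accepted)
        accept_length_list.append(len(accepted))
        new_token_count += len(accepted)
        num_steps += 1
    return output_ids, new_token_count, num_steps, accept_length_list

# A stub "model" that always drafts and accepts 3 tokens per step.
stub = lambda ids: [0, 0, 0]
out, n_new, steps, accepts = toy_forward([1, 2], stub, tokenizer=None, max_new_tokens=6)
print(n_new, steps, accepts)   # 6 2 [3, 3]
```

The porting work in Steps 2–4 is essentially wrapping this shape of function in a class and reformatting its return value.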
### Step 2: Create the specdec_bench Model Class

```python
# specdec_bench/specdec_bench/models/specbench_<method>.py

from .base import Model
import asyncio
import time
import torch

# Import dependencies from Spec-Bench
try:
    import sys
    import os

    spec_bench_path = os.path.join(os.getcwd(), "Spec-Bench")
    sys.path.insert(0, spec_bench_path)

    from model.<method>.<model_file> import <ModelClass>
    from model.<method>.kv_cache import initialize_past_key_values
    from model.<method>.utils import (
        # Import all required utilities
    )
    from model.<method>.<choices_file> import <default_choices>
except ImportError as e:
    print(f"<Method> dependencies not found: {e}")
    <ModelClass> = None


class SpecBench<Method>Model(Model):
    def __init__(self, model_dir, max_concurrent_requests, sampling_kwargs, **kwargs):
        # 1. Validate dependencies
        if <ModelClass> is None:
            raise ImportError("<Method> dependencies not found.")

        # 2. Extract configuration from kwargs
        self.dtype = kwargs.get("dtype", "float16")
        self.max_steps = kwargs.get("max_steps", 512)
        self.temperature = sampling_kwargs.get("temperature", 0.0)
        # ... other method-specific parameters

        # 3. Set up device (avoid device_map="auto" for multi-GPU issues)
        self.device = torch.device(kwargs.get("device", "cuda:0"))

        # 4. Convert dtype string to torch dtype
        dtype_map = {
            "float32": torch.float32,
            "float16": torch.float16,
            "bfloat16": torch.bfloat16,
        }
        torch_dtype = dtype_map.get(self.dtype, torch.float16)

        # 5. Load the model
        self.model = <ModelClass>.from_pretrained(
            model_dir,
            # ... other args from Spec-Bench's __main__
            torch_dtype=torch_dtype,
            low_cpu_mem_usage=True,
        )
        self.model = self.model.to(self.device)

        self.sampling_kwargs = sampling_kwargs
```
### Step 3: Port the Forward Function

Convert the standalone `*_forward()` function to an internal method:

```python
def _forward(self, input_ids, max_new_tokens, end_id):
    """
    Port of the original *_forward function.

    Key changes from Spec-Bench:
    1. input_ids is already a tensor (converted in run())
    2. Add timing list to track per-step timestamps
    3. Use self.device instead of model.base_model.device
    4. Return timing along with other outputs
    """
    accept_length_list = []
    timing = [time.perf_counter()]  # ADD: Track timing

    # === COPY THE FORWARD LOGIC FROM SPEC-BENCH ===
    # Replace: device=model.base_model.device
    # With:    device=self.device

    # Initialize buffers...
    # Initialize KV cache...
    # Main decoding loop...

    for idx in range(self.max_steps):
        # Generate candidates...
        # Tree decoding...
        # Evaluate posterior...
        # Update inputs...

        timing.append(time.perf_counter())  # ADD: Record time per step

        # Check for EOS
        if end_id in input_ids[0, input_len:].tolist():
            break
        if new_token > max_new_tokens:
            break

    return input_ids, new_token, idx + 1, accept_length_list, timing  # ADD timing
```
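The `accept_length_list` and `timing` values returned by `_forward()` are enough to derive the headline speculative-decoding metrics. A small sketch (the helper name `acceptance_stats` is mine, not a specdec_bench API):

```python
def acceptance_stats(accept_length_list, timing):
    """Mean accepted tokens per step and per-step latencies from _forward() outputs.

    timing has one entry before the loop plus one per step, so consecutive
    differences give per-step wall-clock latency.
    """
    steps = len(accept_length_list)
    mean_accept = sum(accept_length_list) / steps if steps else 0.0
    step_latencies = [b - a for a, b in zip(timing, timing[1:])]
    return mean_accept, step_latencies

mean_accept, lat = acceptance_stats([3, 1, 2], [0.0, 0.01, 0.02, 0.04])
print(mean_accept)   # 2.0 accepted tokens per step
print(len(lat))      # 3 per-step latencies
```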
### Step 4: Implement the run() Method

```python
async def run(self, prompt_ids, max_length, end_id, request_id, turn_id):
    """
    Async interface for specdec_bench.

    Args:
        prompt_ids: List of input token IDs (NOT a tensor)
        max_length: Maximum new tokens to generate
        end_id: EOS token ID
        request_id: Request identifier
        turn_id: Turn identifier

    Returns:
        dict with output_ids, output_logits, token_times
    """
    output_dict = {}

    # Convert prompt_ids list to tensor
    input_ids = torch.tensor(
        [prompt_ids], dtype=torch.long, device=self.device
    )

    # Run forward pass (use asyncio.to_thread for sync code)
    result = await asyncio.to_thread(
        self._forward, input_ids, max_length, end_id
    )
    input_ids_out, new_token, num_steps, accept_length_list, timing = result

    # Extract generated tokens (excluding prompt)
    original_len = len(prompt_ids)
    generated_tokens = input_ids_out[0, original_len:].tolist()

    # Remove EOS token if present
    if end_id in generated_tokens:
        eos_idx = generated_tokens.index(end_id)
        generated_tokens = generated_tokens[:eos_idx]

    # Format output_ids as a list of token chunks per step
    # This matches specdec_bench's expected format
    reformatted_output_ids = [[]]
    start = 0
    for accept_len in accept_length_list:
        if accept_len > 0 and start < len(generated_tokens):
            chunk = generated_tokens[start:start + accept_len]
            if chunk:
                reformatted_output_ids[0].append(chunk)
            start += accept_len

    # Handle remaining tokens
    if start < len(generated_tokens):
        reformatted_output_ids[0].append(generated_tokens[start:])

    output_dict['output_ids'] = reformatted_output_ids
    output_dict['output_logits'] = None
    output_dict['token_times'] = timing

    return output_dict
```
227+
228+
### Step 5: Implement stop() for Cleanup
229+
230+
```python
231+
def stop(self):
232+
"""Clean up resources."""
233+
# Clear any cached states
234+
if hasattr(self.model, "past_key_values"):
235+
del self.model.past_key_values
236+
del self.model.past_key_values_data
237+
del self.model.current_length_data
238+
239+
# Clear method-specific buffers
240+
if hasattr(self.model, "<method>_buffers"):
241+
del self.model.<method>_buffers
242+
243+
# Free GPU memory
244+
if hasattr(self, 'model') and self.model is not None:
245+
del self.model
246+
torch.cuda.empty_cache()
247+
```
248+
249+
### Step 6: Register the Model (Optional)
250+
251+
Add to `specdec_bench/specdec_bench/models/__init__.py`:
252+
253+
```python
254+
from .specbench_<method> import SpecBench<Method>Model
255+
```
256+
257+
## Key Differences Summary

| Aspect | Spec-Bench | specdec_bench |
|--------|-----------|---------------|
| Input format | `inputs.input_ids` (tensor from tokenizer) | `prompt_ids` (list of ints) |
| Output format | `(output_ids, new_token, steps, accept_lengths)` | `dict` with `output_ids`, `output_logits`, `token_times` |
| Output IDs | Full sequence tensor | List of token chunks per step |
| Timing | External (in `run_eval`) | Internal (in `run()`) |
| Device | `device_map="auto"` | Explicit single device |
| Interface | Function-based | Class-based with async `run()` |
## Common Pitfalls

1. **Device Mismatch**: Avoid `device_map="auto"`, which spreads the model across GPUs. Use an explicit `.to(device)`.

2. **Tensor vs List**: `prompt_ids` in specdec_bench is a Python list, not a tensor. Convert it in `run()`.

3. **Output Format**: specdec_bench expects `output_ids` as `[[chunk1, chunk2, ...]]` (a list of lists of lists for beam_width=1).

4. **Timing**: Add `time.perf_counter()` calls to track per-step latency.

5. **EOS Handling**: Strip EOS tokens from the output before formatting.

6. **Async Wrapper**: Use `asyncio.to_thread()` to wrap synchronous forward passes.
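Pitfall 3 in particular is easy to get wrong. The chunking logic from Step 4 can be pulled out into a standalone helper (the name `chunk_by_acceptance` is my own, not a specdec_bench API) and sanity-checked in isolation:

```python
def chunk_by_acceptance(generated_tokens, accept_length_list):
    """Split a flat token list into per-step chunks, wrapped for beam_width=1.

    Returns [[chunk1, chunk2, ...]] as specdec_bench expects; any tokens not
    covered by accept_length_list are appended as a final chunk.
    """
    chunks = []
    start = 0
    for accept_len in accept_length_list:
        if accept_len > 0 and start < len(generated_tokens):
            chunk = generated_tokens[start:start + accept_len]
            if chunk:
                chunks.append(chunk)
            start += accept_len
    if start < len(generated_tokens):   # leftover tokens
        chunks.append(generated_tokens[start:])
    return [chunks]

out = chunk_by_acceptance([10, 11, 12, 13, 14, 15], [3, 1])
print(out)   # [[[10, 11, 12], [13], [14, 15]]]
```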
## Example: Mapping Spec-Bench Methods

| Spec-Bench File | Model Class | Forward Function | Key Utils |
|-----------------|-------------|------------------|-----------|
| `inference_medusa.py` | `MedusaModel` | `medusa_forward` | `generate_medusa_buffers`, `initialize_medusa` |
| `inference_eagle.py` | `EaModel` | `ea_forward` | `generate_tree_buffers`, `initialize_tree` |
| `inference_eagle2.py` | `EaModel` | `ea_forward` | Same as EAGLE |
| `inference_hydra.py` | `HydraModel` | `hydra_forward` | `generate_hydra_buffers`, `initialize_hydra` |
| `inference_lookahead.py` | `LookaheadModel` | `lookahead_forward` | Lookahead-specific utils |
## Testing Your Port

```python
import asyncio

async def test():
    model = SpecBench<Method>Model(
        model_dir="/path/to/model",
        max_concurrent_requests=1,
        sampling_kwargs={"temperature": 0.0},
        # method-specific kwargs...
    )

    result = await model.run(
        prompt_ids=[1, 2, 3, 4, 5],  # Example token IDs
        max_length=100,
        end_id=2,  # EOS token
        request_id="test",
        turn_id=0,
    )

    print("Output chunks:", result['output_ids'])
    print("Timing:", result['token_times'])

    model.stop()

asyncio.run(test())
```
