diff --git a/README.md b/README.md
index a7e466d1..432844eb 100644
--- a/README.md
+++ b/README.md
@@ -636,12 +636,12 @@ Benchmark results for Qwen3-VL series models using Eagle3 speculative decoding o
##### 1.2.2 HunyuanOCR Model
-Benchmark results for HunyuanOCR using Eagle3 speculative decoding on vLLM (v0.13.0) across OCR tasks, using a single NVIDIA H20 GPU (**tp=1, ep=1, num_speculative_tokens=4, batch_size=1, output_len=1024**).
+Benchmark results for HunyuanOCR using Eagle3 speculative decoding on vLLM (v0.13.0) across **[OmniDocBench](https://huggingface.co/datasets/opendatalab/OmniDocBench)** dataset, using a single NVIDIA H20 GPU (**tp=1, ep=1, num_speculative_tokens=4, batch_size=1, output_len=1024**).
| Model |
Method |
- OCR-Bench-Internal |
+ OmniDocBench |
@@ -653,13 +653,13 @@ Benchmark results for HunyuanOCR using Eagle3 speculative decoding on vLLM (v0.1
| Hunyuan-OCR |
Vanilla |
- 71.21 |
+ 70.12 |
1 |
| Eagle3 |
- 120.75 |
- 2.2 |
+ 108.1 |
+ 2.08 |
diff --git a/README_cn.md b/README_cn.md
index a47839de..51fcdea3 100644
--- a/README_cn.md
+++ b/README_cn.md
@@ -640,13 +640,13 @@ bash scripts/deploy/lm_eval.sh -d 0,1 -t 2 -g 0.8 -r $RESULT_PATH -b "auto" --ta
##### 1.2.2 HunyuanOCR模型
-我们使用(v0.13.0)评测了HunyuanOCR Eagle3模型在 **OCR-Bench** 上的接收长度和吞吐。结果是在单张H20上用以下设置测得:**tp=1, ep=1, num_speculative_tokens=4, batch_size=1, output_len=1024**。
+我们使用(v0.13.0)评测了HunyuanOCR Eagle3模型在[OmniDocBench](https://huggingface.co/datasets/opendatalab/OmniDocBench)上的接收长度和吞吐。结果是在单张H20上用以下设置测得:**tp=1, ep=1, num_speculative_tokens=4, batch_size=1, output_len=1024**。
| Model |
Method |
- OCR-Bench-Internal |
+ OmniDocBench |
@@ -658,13 +658,13 @@ bash scripts/deploy/lm_eval.sh -d 0,1 -t 2 -g 0.8 -r $RESULT_PATH -b "auto" --ta
| Hunyuan-OCR |
Vanilla |
- 71.21 |
+ 70.12 |
1 |
| Eagle3 |
- 120.75 |
- 2.2 |
+ 108.1 |
+ 2.08 |
diff --git a/angelslim/compressor/speculative/train/configs/qwen2-audio-7b-eagle3.json b/angelslim/compressor/speculative/train/configs/qwen2-audio-7b-eagle3.json
new file mode 100644
index 00000000..ec89f657
--- /dev/null
+++ b/angelslim/compressor/speculative/train/configs/qwen2-audio-7b-eagle3.json
@@ -0,0 +1,30 @@
+{
+ "architectures": [
+ "Eagle3LlamaForCausalLM"
+ ],
+ "model_type": "llama",
+ "target_model_type": "qwen2_audio",
+ "attention_bias": false,
+ "attention_dropout": 0.0,
+ "bos_token_id": 151643,
+ "dtype": "bfloat16",
+ "eos_token_id": 151645,
+ "head_dim": 128,
+ "hidden_act": "silu",
+ "hidden_size": 4096,
+ "initializer_range": 0.02,
+ "intermediate_size": 11008,
+ "max_position_embeddings": 8192,
+ "num_attention_heads": 32,
+ "num_hidden_layers": 1,
+ "num_key_value_heads": 4,
+ "rms_norm_eps": 1e-06,
+ "rope_scaling": null,
+ "rope_theta": 10000,
+ "use_cache": true,
+ "vocab_size": 156032,
+ "tie_word_embeddings": false,
+ "transformers_version": "4.57.1",
+ "draft_vocab_size": 32000,
+ "modal_type": "Audio"
+}
diff --git a/angelslim/compressor/speculative/train/data/data_utils.py b/angelslim/compressor/speculative/train/data/data_utils.py
index 799e0933..591a309d 100644
--- a/angelslim/compressor/speculative/train/data/data_utils.py
+++ b/angelslim/compressor/speculative/train/data/data_utils.py
@@ -341,3 +341,71 @@ def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]:
[paddingtensor2D(item["position_ids"], max_length) for item in features]
)
return batch
+
+
+class AudioDataCollatorWithPadding:
+
+ def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]:
+ max_length = max(item["input_ids"].shape[1] for item in features)
+ batch_input_ids = torch.cat(
+ [paddingtensor2D(item["input_ids"], max_length) for item in features]
+ )
+ batch_attention_mask = torch.cat(
+ [paddingtensor2D(item["attention_mask"], max_length) for item in features]
+ )
+ batch_loss_mask = torch.cat(
+ [paddingtensor2D(item["loss_mask"], max_length) for item in features]
+ )
+
+ batch = {
+ "input_ids": batch_input_ids,
+ "attention_mask": batch_attention_mask,
+ "loss_mask": batch_loss_mask,
+ "feature_attention_mask": None,
+ "input_features": None,
+ "hidden_states": None,
+ "target_hiddens": None,
+ "inputs_embeds": None,
+ "position_ids": None,
+ }
+
+ # Check if both hidden_states and target_hiddens exist in all features
+ if all(
+ "hidden_states" in item and "target_hiddens" in item for item in features
+ ):
+ batch["hidden_states"] = torch.cat(
+ [paddingtensor(item["hidden_states"], max_length) for item in features]
+ )
+ batch["target_hiddens"] = torch.cat(
+ [paddingtensor(item["target_hiddens"], max_length) for item in features]
+ )
+ if all(
+ "inputs_embeds" in item and item["inputs_embeds"] is not None
+ for item in features
+ ):
+ batch["inputs_embeds"] = torch.cat(
+ [paddingtensor(item["inputs_embeds"], max_length) for item in features]
+ )
+ if all(
+ "position_ids" in item and item["position_ids"] is not None
+ for item in features
+ ):
+ batch["position_ids"] = torch.cat(
+ [paddingtensor2D(item["position_ids"], max_length) for item in features]
+ )
+ if all(
+ "feature_attention_mask" in item
+ and item["feature_attention_mask"] is not None
+ for item in features
+ ):
+ batch["feature_attention_mask"] = torch.cat(
+ [(item["feature_attention_mask"]) for item in features]
+ )
+ if all(
+ "input_features" in item and item["input_features"] is not None
+ for item in features
+ ):
+ batch["input_features"] = torch.cat(
+ [(item["input_features"]) for item in features]
+ )
+ return batch
diff --git a/angelslim/compressor/speculative/train/data/dataset_builder/__init__.py b/angelslim/compressor/speculative/train/data/dataset_builder/__init__.py
index dd208cfe..8b0501aa 100644
--- a/angelslim/compressor/speculative/train/data/dataset_builder/__init__.py
+++ b/angelslim/compressor/speculative/train/data/dataset_builder/__init__.py
@@ -19,6 +19,7 @@
OfflineVLMHunyuanVLDatasetBuilder,
)
from .online_dataset_builder import (
+ OnlineAudioDatasetBuilder,
OnlineLLMDatasetBuilder,
OnlineVLMDatasetBuilder,
OnlineVLMHunyuanVLDatasetBuilder,
@@ -32,4 +33,5 @@
"OfflineVLMDatasetBuilder",
"OfflineVLMHunyuanVLDatasetBuilder",
"DatasetBuilderFactory",
+ "OnlineAudioDatasetBuilder",
]
diff --git a/angelslim/compressor/speculative/train/data/dataset_builder/online_dataset_builder.py b/angelslim/compressor/speculative/train/data/dataset_builder/online_dataset_builder.py
index d25c8711..3a160161 100644
--- a/angelslim/compressor/speculative/train/data/dataset_builder/online_dataset_builder.py
+++ b/angelslim/compressor/speculative/train/data/dataset_builder/online_dataset_builder.py
@@ -14,16 +14,19 @@
from typing import Any, Dict, List, Optional, Union
+import requests
import torch
from datasets import Features, Value, load_dataset
from PIL import Image
from torch.utils.data import Dataset
from transformers import AutoProcessor, AutoTokenizer
+from transformers.pipelines.audio_utils import ffmpeg_read
from angelslim.utils import rank0_print
from ..chat_templates import ChatTemplateType
from ..data_utils import (
+ AudioDataCollatorWithPadding,
DataCollatorWithPadding,
VLMDataCollatorWithPadding,
VLMHunyuanDataCollatorWithPadding,
@@ -539,3 +542,271 @@ def _extract_vision_info(self, messages: List[Dict]) -> tuple:
video_paths.append(item["video"])
return image_paths, video_paths
+
+
+@DatasetBuilderFactory.register("online", "Audio", "qwen2_audio")
+class OnlineAudioDatasetBuilder(OnlineDatasetBuilder):
+ def __init__(
+ self,
+ tokenizer: Union[AutoTokenizer, AutoProcessor],
+ max_length: int = 2048,
+ shuffle_seed: int = 42,
+ chat_template_type: ChatTemplateType = ChatTemplateType.QWEN3,
+ display: bool = False,
+ **kwargs: Any,
+ ):
+ super().__init__(
+ tokenizer,
+ max_length,
+ shuffle_seed,
+ chat_template_type,
+ display,
+ )
+
+ def build_dataset(
+ self,
+ datapath: str,
+ num_proc: int = 8,
+ shuffle: bool = True,
+ sample_num: Optional[int] = None,
+ ) -> Dataset:
+ try:
+ # Load dataset
+ features = Features(
+ {
+ "id": Value("string"),
+ "conversations": [
+ {
+ "role": Value("string"),
+ "content": [
+ {
+ "type": Value("string"),
+ "text": Value("string"),
+ "audio": Value("string"),
+ }
+ ],
+ }
+ ],
+ }
+ )
+ ds = load_dataset("json", data_files=datapath, features=features)
+
+ # Conditionally shuffle dataset
+ if shuffle:
+ ds = ds["train"].shuffle(seed=self.shuffle_seed)
+ else:
+ ds = ds["train"]
+
+ if sample_num is not None and 0 < sample_num < len(ds):
+ ds = ds.select(range(sample_num))
+
+ # Store original columns for removal
+ original_columns = ds.column_names
+
+ # Apply preprocessing
+ processed_ds = ds.map(
+ self._preprocess_function,
+ batched=True,
+ num_proc=num_proc,
+ remove_columns=original_columns,
+ load_from_cache_file=False,
+ desc="Processing conversations",
+ )
+
+ # Filter out None results with multiprocessing support
+ processed_ds = processed_ds.filter(
+ lambda batch: [ids is not None for ids in batch["input_ids"]],
+ batched=True,
+ num_proc=num_proc,
+ desc="Filtering empty input_ids",
+ )
+ processed_ds.set_format(type="torch")
+
+ return processed_ds
+
+ except Exception as e:
+ raise RuntimeError(f"Dataset building failed for {datapath}") from e
+
+ def get_data_collator(self) -> Any:
+ return AudioDataCollatorWithPadding()
+
+ def read_audio(self, audio_path):
+ if audio_path.startswith("http://") or audio_path.startswith("https://"):
+ inputs = requests.get(audio_path).content
+ else:
+ with open(audio_path, "rb") as f:
+ inputs = f.read()
+ return inputs
+
+ def _preprocess_function(self, examples: Dict[str, List]) -> Dict[str, List]:
+ new_examples = {
+ "input_ids": [],
+ "attention_mask": [],
+ "loss_mask": [],
+ "input_features": [],
+ "feature_attention_mask": [],
+ }
+
+ for i in range(len(examples["id"])):
+ try:
+ processed_example = self._process_single_conversation(
+ examples["conversations"][i]
+ )
+
+ if processed_example is not None:
+ for key in new_examples.keys():
+ if key not in processed_example:
+ new_examples[key].append(None)
+ else:
+ new_examples[key].append(processed_example[key])
+
+ except Exception as e:
+ rank0_print(f"Error processing example: {e}")
+ # Add None placeholders to maintain batch consistency
+ for key in new_examples:
+ new_examples[key].append(None)
+
+ cleaned_new_examples = {}
+ for key, value in new_examples.items():
+ if any(v is not None for v in value):
+ cleaned_new_examples[key] = value
+
+ return cleaned_new_examples
+
+ def _visualize_loss_mask(
+ self, input_ids: torch.Tensor, loss_mask: torch.Tensor, conversation: str
+ ) -> None:
+ """
+ Visualize loss_mask with color-coded output.
+
+ Args:
+ input_ids: Token IDs
+ loss_mask: Loss mask tensor (1 for training, 0 for ignoring)
+ conversation: Original conversation text
+ """
+ input_ids = input_ids.view(-1)
+ return super()._visualize_loss_mask(input_ids, loss_mask, conversation)
+
+ def _create_loss_mask_from_offsets(
+ self, conversation: str, offsets: torch.Tensor
+ ) -> torch.Tensor:
+ if offsets.ndim == 3:
+ offsets = offsets[0]
+ return super()._create_loss_mask_from_offsets(conversation, offsets)
+
+ def _extract_audio_info(self, messages: List[Dict]) -> tuple:
+ """Extract Audio paths from messages"""
+ audio_paths = []
+
+ sampling_rate = self.tokenizer.feature_extractor.sampling_rate
+ for message in messages:
+ content = message.get("content", [])
+ if not isinstance(content, list):
+ continue
+
+ for item in content:
+ if item.get("type") == "audio":
+ # Handle both file paths and PIL images
+ if isinstance(item["audio"], str):
+ try:
+ audio_paths.append(
+ ffmpeg_read(
+ self.read_audio(item["audio"]),
+ sampling_rate=sampling_rate,
+ )
+ )
+ except ValueError as e:
+ raise ValueError(
+ f"Could not open audio file: {item['audio']}, {e}"
+ )
+ return audio_paths
+
+ def _process_single_conversation(
+ self, conversation_data: List[Dict]
+ ) -> Optional[Dict]:
+ if not conversation_data or not isinstance(conversation_data, list):
+ return None
+
+ try:
+ # Build messages with system prompt
+ messages = self._build_messages(conversation_data)
+ if not messages:
+ return None
+
+ # Apply chat template
+ assert isinstance(
+ messages, list
+ ), f"type(messages)={type(messages)} is not list"
+ for message in messages:
+ if isinstance(message["content"], str):
+ continue
+ assert isinstance(
+ message["content"], list
+ ), f"content={type(message['content'])} is not str or list"
+ new_content = []
+ for item in message["content"]:
+ new_item = {"type": item["type"], item["type"]: item[item["type"]]}
+ new_content.append(new_item)
+ del message["content"]
+ message["content"] = new_content
+
+ input_text = self.tokenizer.apply_chat_template(
+ messages, add_generation_prompt=False, tokenize=False
+ )
+ input_audios = self._extract_audio_info(messages)
+
+ # cannot set max_length,
+ # otherwise the input_ids audio token length will be aligned(missing)
+ encoding = self.tokenizer(
+ text=input_text,
+ audio=input_audios,
+ sampling_rate=self.tokenizer.feature_extractor.sampling_rate,
+ return_offsets_mapping=True,
+ return_tensors="pt",
+ truncation=True,
+ padding=False,
+ )
+ input_ids = encoding["input_ids"]
+ offsets = encoding["offset_mapping"]
+
+ conversation = self.tokenizer.decode(
+ input_ids[0], skip_special_tokens=False
+ )
+
+ # Create loss mask for assistant responses
+ try:
+ loss_mask = self._create_loss_mask_from_offsets(conversation, offsets)
+ except Exception as e:
+ rank0_print(f"Error creating loss mask: {e}")
+ rank0_print(f"offsets: {offsets}")
+ raise e
+ attention_mask = torch.ones_like(input_ids)
+
+ # Visualize loss mask if display mode is enabled
+ if self.display and self.display_count == 0:
+ try:
+ self._visualize_loss_mask(input_ids, loss_mask, conversation)
+ except Exception as e:
+ rank0_print(f"Error visualizing loss mask: {e}")
+ rank0_print(f"input_ids: {input_ids}, loss_mask: {loss_mask}")
+ raise e
+ self.display_count += 1
+
+ result_dict = {
+ "input_ids": input_ids.view(1, -1),
+ "attention_mask": attention_mask.view(1, -1),
+ "loss_mask": loss_mask.view(1, -1),
+ }
+
+ if "input_features" in encoding:
+ result_dict["input_features"] = encoding["input_features"]
+ if "feature_attention_mask" in encoding:
+ result_dict["feature_attention_mask"] = encoding[
+ "feature_attention_mask"
+ ]
+
+ return result_dict
+
+ except Exception as e:
+ rank0_print(f"Error processing conversation: {e}")
+ return None
diff --git a/angelslim/compressor/speculative/train/models/target/target_model_wrapper.py b/angelslim/compressor/speculative/train/models/target/target_model_wrapper.py
index a7bf2d43..c3087186 100644
--- a/angelslim/compressor/speculative/train/models/target/target_model_wrapper.py
+++ b/angelslim/compressor/speculative/train/models/target/target_model_wrapper.py
@@ -286,7 +286,15 @@ def get_aux_and_target_hiddens(
class VLMTransformersBackend(BaseBackend):
"""VLM HuggingFace Transformers backend"""
+ SUPPORT_MODEL_TYPE = ["hunyuan_vl", "qwen3_vl"]
+
def load_model(self):
+ if (
+ self.target_model_type is None
+ or self.target_model_type not in self.SUPPORT_MODEL_TYPE
+ ):
+ raise ValueError(f"{self.target_model_type} is not supported now!")
+
if self.target_model_type == "hunyuan_vl":
from transformers import AutoProcessor, HunYuanVLForConditionalGeneration
@@ -305,7 +313,7 @@ def load_model(self):
self.tokenizer = AutoProcessor.from_pretrained(
self.model_path, trust_remote_code=True
)
- else:
+ elif self.target_model_type == "qwen3_vl":
from transformers import AutoModelForImageTextToText, AutoProcessor
device = decide_device_for_distributed()
@@ -328,6 +336,8 @@ def load_model(self):
self.model_path,
trust_remote_code=True,
)
+ else:
+ raise ValueError(f"Unsupported target model type: {self.target_model_type}")
def _prepare_model_kwargs(self, device: str) -> dict:
"""
@@ -517,6 +527,201 @@ def hook(module, args, kwargs):
}
+class AudioTransformersBackend(BaseBackend):
+ """Audio HuggingFace Transformers backend"""
+
+ SUPPORT_MODEL_TYPE = ["qwen2_audio"]
+
+ def load_model(self):
+ if (
+ self.target_model_type is None
+ or self.target_model_type not in self.SUPPORT_MODEL_TYPE
+ ):
+ raise ValueError(f"{self.target_model_type} is not supported now!")
+
+ if self.target_model_type == "qwen2_audio":
+ from transformers import (
+ Qwen2AudioForConditionalGeneration,
+ Qwen2AudioProcessor,
+ )
+
+ device = decide_device_for_distributed()
+ print_with_rank(f"Loading model to device: {device}")
+
+ # Prepare model loading configuration
+ model_kwargs = self._prepare_model_kwargs(device)
+
+ self.model = Qwen2AudioForConditionalGeneration.from_pretrained(
+ self.model_path, **model_kwargs
+ )
+
+ # Freeze the base model
+ for param in self.model.parameters():
+ param.requires_grad = False
+ self.model.eval()
+
+ self.tokenizer = Qwen2AudioProcessor.from_pretrained(
+ self.model_path, trust_remote_code=True
+ )
+ else:
+ raise ValueError(f"Unsupported target model type: {self.target_model_type}")
+
+ def _prepare_model_kwargs(self, device: str) -> dict:
+ """
+ Prepare keyword arguments for model loading.
+
+ Args:
+ device: Target device for model placement
+
+ Returns:
+ Dictionary of model loading arguments
+ """
+ default_kwargs = {
+ "dtype": torch.bfloat16,
+ "device_map": device,
+ "trust_remote_code": True,
+ }
+ default_kwargs.update(self.kwargs)
+ return default_kwargs
+
+ def get_hidden_states_and_logits(
+ self,
+ input_ids: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ **kwargs,
+ ) -> Tuple[torch.Tensor, ...]:
+ """
+ Extract hidden states and logits using Transformers backend.
+
+ Args:
+ input_ids: Input token IDs
+ attention_mask: Attention mask
+ **kwargs: May contain 'aux_hidden_states_layer_ids' to specify custom layers
+
+ Returns:
+ Tuple of (concatenated_hidden_states, logits)
+ """
+ inputs_embeds_list, position_ids_list = [], []
+
+ def hook(module, args, kwargs):
+ if "inputs_embeds" in kwargs and kwargs["inputs_embeds"] is not None:
+ inputs_embeds_list.append(
+ kwargs["inputs_embeds"].clone().detach().cpu()
+ )
+ if "position_ids" in kwargs and kwargs["position_ids"] is not None:
+ position_ids_list.append(kwargs["position_ids"].clone().detach().cpu())
+ return args, kwargs
+
+ handle = self.model.language_model.register_forward_pre_hook(
+ hook, with_kwargs=True
+ )
+ input_features = kwargs.get("input_features", None)
+ feature_attention_mask = kwargs.get("feature_attention_mask", None)
+ with torch.no_grad():
+ outputs = self.model(
+ input_ids,
+ attention_mask=attention_mask,
+ input_features=input_features,
+ feature_attention_mask=feature_attention_mask,
+ output_hidden_states=True,
+ return_dict=True,
+ )
+
+ handle.remove()
+
+ inputs_embeds = (
+ inputs_embeds_list[0].to(input_ids.device) if inputs_embeds_list else None
+ )
+ position_ids = (
+ position_ids_list[0].to(input_ids.device) if position_ids_list else None
+ )
+
+ # Extract auxiliary hidden states
+ aux_layer_ids = kwargs.get("aux_hidden_states_layer_ids", None)
+ hidden_states = self._extract_auxiliary_hidden_states(
+ outputs.hidden_states, aux_layer_ids
+ )
+
+ # Return hidden states and logits on the same device as input
+ return (
+ hidden_states,
+ outputs.logits.to(input_ids.device),
+ inputs_embeds,
+ position_ids,
+ )
+
+ def get_aux_and_target_hiddens(
+ self,
+ input_ids: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ **kwargs,
+ ) -> dict[str, torch.Tensor]:
+ """
+ Extract auxiliary and final layer hidden states.
+
+ Args:
+ input_ids: Input token IDs
+ attention_mask: Attention mask
+ **kwargs: May contain 'aux_hidden_states_layer_ids' to specify custom layers
+
+ Returns:
+ Tuple of (auxiliary_hidden_states, final_hidden_states)
+ """
+ inputs_embeds_list, position_ids_list = [], []
+
+ def hook(module, args, kwargs):
+ if "inputs_embeds" in kwargs and kwargs["inputs_embeds"] is not None:
+ inputs_embeds_list.append(
+ kwargs["inputs_embeds"].clone().detach().cpu()
+ )
+ if "position_ids" in kwargs and kwargs["position_ids"] is not None:
+ position_ids_list.append(kwargs["position_ids"].clone().detach().cpu())
+ return args, kwargs
+
+ handle = self.model.language_model.register_forward_pre_hook(
+ hook, with_kwargs=True
+ )
+ input_features = kwargs.get("input_features", None)
+ feature_attention_mask = kwargs.get("feature_attention_mask", None)
+ with torch.no_grad():
+ outputs = self.model(
+ input_ids,
+ attention_mask=attention_mask,
+ input_features=input_features,
+ feature_attention_mask=feature_attention_mask,
+ output_hidden_states=True,
+ return_dict=True,
+ )
+
+ handle.remove()
+ inputs_embeds = (
+ inputs_embeds_list[0].to(input_ids.device) if inputs_embeds_list else None
+ )
+ position_ids = (
+ position_ids_list[0].to(input_ids.device) if position_ids_list else None
+ )
+
+ # Extract auxiliary hidden states
+ aux_layer_ids = kwargs.get("aux_hidden_states_layer_ids", None)
+ aux_hidden_states = self._extract_auxiliary_hidden_states(
+ outputs.hidden_states, aux_layer_ids
+ )
+
+ # Get final layer hidden states
+ target_hidden_states = outputs.hidden_states[-1]
+
+ # hidden_states: B, N, 3*D
+ # target_hiddens: B, N, D
+ # inputs_embeds: B, N, D
+ # position_ids: 3, N
+ return {
+ "hidden_states": aux_hidden_states,
+ "target_hiddens": target_hidden_states,
+ "inputs_embeds": inputs_embeds,
+ "position_ids": position_ids,
+ }
+
+
class TargetModelWrapper:
"""
Unified wrapper for target models in Eagle3 training.
@@ -544,6 +749,7 @@ class TargetModelWrapper:
BACKENDS = {
("hf", "LLM"): TransformersBackend,
("hf", "VLM"): VLMTransformersBackend,
+ ("hf", "Audio"): AudioTransformersBackend,
}
def __init__(
diff --git a/angelslim/compressor/speculative/train/trainer/online_eagle3_trainer.py b/angelslim/compressor/speculative/train/trainer/online_eagle3_trainer.py
index 74af8c8c..e0a91d6d 100644
--- a/angelslim/compressor/speculative/train/trainer/online_eagle3_trainer.py
+++ b/angelslim/compressor/speculative/train/trainer/online_eagle3_trainer.py
@@ -147,3 +147,74 @@ def prepare_data_for_draft_model(self, inputs):
"position_ids": position_ids,
"attention_mask": attention_mask,
}
+
+
+@Eagle3TrainerFactory.register("online", "Audio")
+class OnlineAudioEagle3Trainer(Eagle3Trainer):
+ """
+ Online EAGLE3 Trainer for speculative decoding training.
+ Implements training logic for EAGLE3 model using a draft model to predict
+ tokens based on hidden states from a target model.
+ """
+
+ def __init__(
+ self,
+ draft_model: nn.Module,
+ target_model: nn.Module,
+ length: int,
+ draft_model_config: Dict[str, Any],
+ **kwargs,
+ ):
+ """
+ Initialize the OnlineEagle3Trainer.
+ Args:
+ draft_model: Draft model for token prediction
+ target_model: Target model for generating hidden states
+ length: Number of speculative decoding steps
+ draft_model_config: Configuration dictionary for draft model
+ **kwargs: Additional arguments passed to parent Trainer
+ """
+ super().__init__(draft_model=draft_model, length=length, **kwargs)
+ self.target_model = target_model
+ self._aux_hidden_states_layer_ids = getattr(
+ draft_model_config, "aux_hidden_states_layer_ids", None
+ )
+
+ def prepare_data_for_draft_model(self, inputs):
+ input_ids = inputs["input_ids"]
+ attention_mask = inputs["attention_mask"]
+ loss_mask = inputs["loss_mask"]
+
+ kwargs = {
+ k: v
+ for k, v in inputs.items()
+ if k not in ["input_ids", "attention_mask", "loss_mask"]
+ }
+ # Get hidden states and logits from target model
+ hidden_states, target_logits, _, position_ids = (
+ self.target_model.get_hidden_states_and_logits(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ aux_hidden_states_layer_ids=self._aux_hidden_states_layer_ids,
+ **kwargs,
+ )
+ )
+
+ # Apply right padding and move tensors to correct device
+ target_logits = padding(target_logits, left=False).to(input_ids.device)
+ input_ids = padding(input_ids, left=False)
+ loss_mask = loss_mask[..., None].to(input_ids.device)
+
+ result_dict = {}
+ result_dict.update(kwargs)
+ result_dict.update(
+ {
+ "hidden_states": hidden_states,
+ "target_logits": target_logits,
+ "input_ids": input_ids,
+ "position_ids": position_ids,
+ "loss_mask": loss_mask,
+ "attention_mask": attention_mask,
+ }
+ )
+ return result_dict
diff --git a/dataset/librispeech_test/audios/1255-90413-0010.flac b/dataset/librispeech_test/audios/1255-90413-0010.flac
new file mode 100644
index 00000000..4aefb854
Binary files /dev/null and b/dataset/librispeech_test/audios/1255-90413-0010.flac differ
diff --git a/dataset/librispeech_test/audios/1580-141083-0008.flac b/dataset/librispeech_test/audios/1580-141083-0008.flac
new file mode 100644
index 00000000..1b39a5e9
Binary files /dev/null and b/dataset/librispeech_test/audios/1580-141083-0008.flac differ
diff --git a/dataset/librispeech_test/audios/1995-1837-0019.flac b/dataset/librispeech_test/audios/1995-1837-0019.flac
new file mode 100644
index 00000000..3e6e4a78
Binary files /dev/null and b/dataset/librispeech_test/audios/1995-1837-0019.flac differ
diff --git a/dataset/librispeech_test/audios/2803-154328-0004.flac b/dataset/librispeech_test/audios/2803-154328-0004.flac
new file mode 100644
index 00000000..23473006
Binary files /dev/null and b/dataset/librispeech_test/audios/2803-154328-0004.flac differ
diff --git a/dataset/librispeech_test/audios/5694-64029-0007.flac b/dataset/librispeech_test/audios/5694-64029-0007.flac
new file mode 100644
index 00000000..d19132a7
Binary files /dev/null and b/dataset/librispeech_test/audios/5694-64029-0007.flac differ
diff --git a/dataset/librispeech_test/audios/61-70970-0032.flac b/dataset/librispeech_test/audios/61-70970-0032.flac
new file mode 100644
index 00000000..b6a4de64
Binary files /dev/null and b/dataset/librispeech_test/audios/61-70970-0032.flac differ
diff --git a/dataset/librispeech_test/audios/6267-53049-0027.flac b/dataset/librispeech_test/audios/6267-53049-0027.flac
new file mode 100644
index 00000000..ed70cbd3
Binary files /dev/null and b/dataset/librispeech_test/audios/6267-53049-0027.flac differ
diff --git a/dataset/librispeech_test/audios/6345-93306-0021.flac b/dataset/librispeech_test/audios/6345-93306-0021.flac
new file mode 100644
index 00000000..da137856
Binary files /dev/null and b/dataset/librispeech_test/audios/6345-93306-0021.flac differ
diff --git a/dataset/librispeech_test/audios/700-122867-0027.flac b/dataset/librispeech_test/audios/700-122867-0027.flac
new file mode 100644
index 00000000..55bc3604
Binary files /dev/null and b/dataset/librispeech_test/audios/700-122867-0027.flac differ
diff --git a/dataset/librispeech_test/audios/8188-269288-0038.flac b/dataset/librispeech_test/audios/8188-269288-0038.flac
new file mode 100644
index 00000000..3ee3855b
Binary files /dev/null and b/dataset/librispeech_test/audios/8188-269288-0038.flac differ
diff --git a/dataset/librispeech_test/librispeech_eval_10_test.jsonl b/dataset/librispeech_test/librispeech_eval_10_test.jsonl
new file mode 100644
index 00000000..e2dd608d
--- /dev/null
+++ b/dataset/librispeech_test/librispeech_eval_10_test.jsonl
@@ -0,0 +1,10 @@
+{"id": 5910, "conversations": [{"role": "user", "content": [{"type": "audio", "audio": "./audios/1580-141083-0008.flac"}, {"type": "text", "text": "Detect the language and recognize the speech: <|en|>"}]}, {"role": "assistant", "content": [{"type": "text", "text": "THE PROOF WAS IN THREE LONG SLIPS I HAD LEFT THEM ALL TOGETHER"}]}]}
+{"id": 4540, "conversations": [{"role": "user", "content": [{"type": "audio", "audio": "./audios/6267-53049-0027.flac"}, {"type": "text", "text": "Detect the language and recognize the speech: <|en|>"}]}, {"role": "assistant", "content": [{"type": "text", "text": "PENELOPE WAS FOUR YEARS OLDER THAN I WAS BUT WE WERE DEVOTED TO EACH OTHER"}]}]}
+{"id": 984, "conversations": [{"role": "user", "content": [{"type": "audio", "audio": "./audios/2803-154328-0004.flac"}, {"type": "text", "text": "Detect the language and recognize the speech: <|en|>"}]}, {"role": "assistant", "content": [{"type": "text", "text": "IF IT IS DECREED THAT WE DIE TO MORROW LET US DIE BRAVELY LIKE CHRISTIAN MEN READY TO APPEAR WITHOUT TERROR BEFORE THE SUPREME JUDGE"}]}]}
+{"id": 7344, "conversations": [{"role": "user", "content": [{"type": "audio", "audio": "./audios/61-70970-0032.flac"}, {"type": "text", "text": "Detect the language and recognize the speech: <|en|>"}]}, {"role": "assistant", "content": [{"type": "text", "text": "ENQUIRED ROBIN WITH HIS SUSPICIONS STILL UPON HIM"}]}]}
+{"id": 10899, "conversations": [{"role": "user", "content": [{"type": "audio", "audio": "./audios/8188-269288-0038.flac"}, {"type": "text", "text": "Detect the language and recognize the speech: <|en|>"}]}, {"role": "assistant", "content": [{"type": "text", "text": "I MEAN THAT I DON'T WANT YOU TO BEGIN TO ASK QUESTIONS"}]}]}
+{"id": 2894, "conversations": [{"role": "user", "content": [{"type": "audio", "audio": "./audios/1255-90413-0010.flac"}, {"type": "text", "text": "Detect the language and recognize the speech: <|en|>"}]}, {"role": "assistant", "content": [{"type": "text", "text": "JUST AS THEY TURN MADEIRA INTO PORT IN THE SPACE OF A SINGLE NIGHT SO THIS OLD AIR HAS BEEN TAKEN AND DOCTORED AND TWISTED ABOUT AND BROUGHT OUT AS A NEW POPULAR DITTY INDEED"}]}]}
+{"id": 1694, "conversations": [{"role": "user", "content": [{"type": "audio", "audio": "./audios/5694-64029-0007.flac"}, {"type": "text", "text": "Detect the language and recognize the speech: <|en|>"}]}, {"role": "assistant", "content": [{"type": "text", "text": "I SOON FOUND OUT THAT HE HAD CAUGHT SIGHT OF THE RELIEF ON THE ROAD AND WAS AFRAID TO SHOOT I QUICKLY MADE UP MY MIND"}]}]}
+{"id": 6068, "conversations": [{"role": "user", "content": [{"type": "audio", "audio": "./audios/1995-1837-0019.flac"}, {"type": "text", "text": "Detect the language and recognize the speech: <|en|>"}]}, {"role": "assistant", "content": [{"type": "text", "text": "HE SAT DOWN WEAK BEWILDERED AND ONE THOUGHT WAS UPPERMOST ZORA"}]}]}
+{"id": 5021, "conversations": [{"role": "user", "content": [{"type": "audio", "audio": "./audios/700-122867-0027.flac"}, {"type": "text", "text": "Detect the language and recognize the speech: <|en|>"}]}, {"role": "assistant", "content": [{"type": "text", "text": "IT SEEMS SUCH A TRAGIC THING"}]}]}
+{"id": 2201, "conversations": [{"role": "user", "content": [{"type": "audio", "audio": "./audios/6345-93306-0021.flac"}, {"type": "text", "text": "Detect the language and recognize the speech: <|en|>"}]}, {"role": "assistant", "content": [{"type": "text", "text": "ALL THE SAME HE ADDED IRRELEVANTLY YOU SHALL HAVE THE LATCH KEY"}]}]}
\ No newline at end of file
diff --git a/docs/source/assets/speculative_decoding/eagle3_speedup_and_accepted_length.png b/docs/source/assets/speculative_decoding/eagle3_speedup_and_accepted_length.png
index ccaf1bf0..0a658950 100644
Binary files a/docs/source/assets/speculative_decoding/eagle3_speedup_and_accepted_length.png and b/docs/source/assets/speculative_decoding/eagle3_speedup_and_accepted_length.png differ
diff --git a/requirements/requirements_multimodal.txt b/requirements/requirements_multimodal.txt
index 0c073e6f..621f664f 100644
--- a/requirements/requirements_multimodal.txt
+++ b/requirements/requirements_multimodal.txt
@@ -1,2 +1,3 @@
qwen_vl_utils==0.0.11
qwen_omni_utils
+mistral_common
\ No newline at end of file
diff --git a/scripts/speculative/qwen2_audio/train_eagle3_audio_online.sh b/scripts/speculative/qwen2_audio/train_eagle3_audio_online.sh
new file mode 100644
index 00000000..3f91e1be
--- /dev/null
+++ b/scripts/speculative/qwen2_audio/train_eagle3_audio_online.sh
@@ -0,0 +1,40 @@
+#!/bin/bash
+
+export CONFIG_DIR=angelslim/compressor/speculative/train/configs
+export TARGET_MODEL_NAME_OR_PATH=Qwen/Qwen2-Audio-7B-Instruct
+export DRAFT_MODEL_CONFIG_PATH=$CONFIG_DIR/qwen2-audio-7b-eagle3.json
+export TRAIN_DATA_PATH=
+export EVAL_DATA_PATH=
+export OUTPUT_DIR=
+export EMBED_WEIGHT_KEY="language_model.model.embed_tokens.weight"
+export MODEL_MAX_LENGTH=4096
+export CHAT_TEMPLATE_TYPE=qwen2_audio
+export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
+
+torchrun --nproc_per_node=8 tools/train_eagle3_online.py \
+ --modal_type Audio \
+ --target_model_name_or_path $TARGET_MODEL_NAME_OR_PATH \
+ --draft_model_config_path $DRAFT_MODEL_CONFIG_PATH \
+ --train_data_path $TRAIN_DATA_PATH \
+ --eval_data_path $EVAL_DATA_PATH \
+ --output_dir $OUTPUT_DIR \
+ --num_train_epochs 20 \
+ --per_device_train_batch_size 1 \
+ --per_device_eval_batch_size 1 \
+ --gradient_accumulation_steps 1 \
+ --num_proc 4 \
+ --save_strategy "steps" \
+ --eval_strategy "steps" \
+ --save_steps 2000 \
+ --eval_steps 2000 \
+ --learning_rate 1e-4 \
+ --weight_decay 0.0 \
+ --warmup_ratio 0.1 \
+ --lr_scheduler_type "constant" \
+ --logging_steps 20 \
+ --model_max_length $MODEL_MAX_LENGTH \
+ --embed_weight_key $EMBED_WEIGHT_KEY \
+ --deepspeed $CONFIG_DIR/deepspeed_zero3.json \
+ --chat_template_type $CHAT_TEMPLATE_TYPE \
+ --report_to wandb \
+ --run_name qwen2-audio-7b-instruct-eagle3
diff --git a/tools/train_eagle3_online.py b/tools/train_eagle3_online.py
index ca09237e..02b1b2e2 100644
--- a/tools/train_eagle3_online.py
+++ b/tools/train_eagle3_online.py
@@ -40,7 +40,7 @@ def parse_args():
"--modal_type",
type=str,
default="LLM",
- choices=["LLM", "VLM"],
+ choices=["LLM", "VLM", "Audio"],
help="Modal type: LLM for language models, VLM for vision-language models",
)
model_group.add_argument(
diff --git a/tools/vllm_offline_eagle3_qwen2_audio_bench.py b/tools/vllm_offline_eagle3_qwen2_audio_bench.py
new file mode 100644
index 00000000..25d4ac2d
--- /dev/null
+++ b/tools/vllm_offline_eagle3_qwen2_audio_bench.py
@@ -0,0 +1,292 @@
+# Copyright 2025 Tencent Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+usage:
+python3 ./tools/vllm_offline_eagle3_qwen2_audio_bench.py \
+ --target_model "$MODEL_DIR" \
+ --draft_model "$EAGLE_DIR" \
+ --use_eagle \
+ --num_spec_tokens 4 \
+ --num_prompts 10 \
+ --temp 0 \
+ --max_num_seqs 1 \
+ --output_len 1024 \
+ --output_file "$OUTPUT_FILE"
+"""
+
+import argparse
+import os
+import time
+from dataclasses import asdict
+from datetime import datetime
+from typing import Any, NamedTuple
+
+from mistral_common.audio import Audio
+from vllm import LLM, EngineArgs, SamplingParams
+from vllm.v1.metrics.reader import Counter, Vector
+
+
+class ModelRequestData(NamedTuple):
+ engine_args: EngineArgs
+ prompt: str | None = None
+ prompt_token_ids: dict[str, list[int]] | None = None
+ multi_modal_data: dict[str, Any] | None = None
+ stop_token_ids: list[int] | None = None
+
+
+# NOTE: The default `max_num_seqs` and `max_model_len` may result in OOM on
+# lower-end GPUs.
+# Unless specified, these settings have been tested to work on a single L4.
+
+
+# Qwen2-Audio
+def run_qwen2_audio(args, question: str, audio_count: int) -> ModelRequestData:
+ num_spec_tokens = args.num_spec_tokens
+
+ speculative_config = None
+ if args.use_eagle:
+ if args.draft_model:
+ speculative_config = {
+ "method": "eagle3",
+ "model": args.draft_model,
+ "num_speculative_tokens": num_spec_tokens,
+ "prefill_token_shift": False,
+ }
+ else:
+ print(
+ "Warning: use_eagle is set but no draft_model provided. "
+ "Running without speculative decoding."
+ )
+
+ engine_args = EngineArgs(
+ model=args.target_model,
+ max_model_len=args.max_model_len,
+ max_num_seqs=args.max_num_seqs,
+ limit_mm_per_prompt={"audio": audio_count},
+ speculative_config=speculative_config,
+ enforce_eager=True,
+ )
+
+ audio_in_prompt = "".join(
+ [
+ f"Audio {idx + 1}: <|audio_bos|><|AUDIO|><|audio_eos|>\n"
+ for idx in range(audio_count)
+ ]
+ )
+
+ prompt = (
+ "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n"
+ "<|im_start|>user\n"
+ f"{audio_in_prompt}{question}<|im_end|>\n"
+ "<|im_start|>assistant\n"
+ )
+
+ return ModelRequestData(
+ engine_args=engine_args,
+ prompt=prompt,
+ )
+
+
+def parse_args():
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ "--target_model", type=str, default=None, help="Path to target model"
+ )
+ parser.add_argument(
+ "--draft_model", type=str, default=None, help="Path to draft model"
+ )
+ parser.add_argument(
+ "--use_eagle",
+ action="store_true",
+ help="Enable speculative decoding with Eagle",
+ )
+ parser.add_argument(
+ "--num_spec_tokens", type=int, default=2, help="Number of speculative tokens"
+ )
+ parser.add_argument("--max_num_seqs", type=int, default=1)
+ parser.add_argument("--max_model_len", type=int, default=1024)
+ parser.add_argument(
+ "--num_prompts", type=int, default=100, help="Number of prompts to run"
+ )
+ parser.add_argument("--output_file", type=str, default="None", help="Output file")
+ parser.add_argument("--temp", type=float, default=0, help="./results")
+ parser.add_argument(
+ "--seed",
+ type=int,
+ default=None,
+ help="Set the seed when initializing `vllm.LLM`.",
+ )
+ parser.add_argument("--output_len", type=int, default=1024)
+ parser.add_argument(
+ "--tp",
+ type=int,
+ default=1,
+ help="Tensor parallel size to override the model's default setting. ",
+ )
+ parser.add_argument(
+ "--test_data_path",
+ type=str,
+ default="dataset/librispeech_test/librispeech_eval_10_test.jsonl",
+ help="Dataset to run",
+ )
+ return parser.parse_args()
+
+
+def main():
+ args = parse_args()
+ if args.tp is not None and args.tp < 1:
+ raise ValueError(
+ f"tensor_parallel_size must be a positive integer, " f"got {args.tp}"
+ )
+ audio_count = 1
+
+ req_data = run_qwen2_audio(
+ args, "Transcribe speech to text. <|en|>", audio_count=audio_count
+ )
+
+ # Disable other modalities to save memory
+ default_limits = {"image": 0, "video": 0, "audio": 0}
+ req_data.engine_args.limit_mm_per_prompt = default_limits | dict(
+ req_data.engine_args.limit_mm_per_prompt or {}
+ )
+
+ engine_args = asdict(req_data.engine_args) | {"seed": args.seed}
+ if args.tp is not None:
+ engine_args["tensor_parallel_size"] = args.tp
+ llm = LLM(**engine_args)
+
+ sampling_params = SamplingParams(
+ temperature=args.temp,
+ max_tokens=args.output_len,
+ stop_token_ids=req_data.stop_token_ids,
+ )
+
+ import json
+
+ inputs_list = []
+ num_prompts = args.num_prompts
+ with open(args.test_data_path, "r", encoding="utf-8") as f:
+ for line_num, line in enumerate(f, start=1):
+ if line_num > num_prompts:
+ break
+ data_line = json.loads(line)
+
+ mm_data = req_data.multi_modal_data
+ if not mm_data:
+ mm_data = {}
+ if audio_count > 0:
+ audio_path = data_line["conversations"][0]["content"][0]["audio"]
+ audio_path = os.path.join(
+ os.path.dirname(args.test_data_path), audio_path
+ )
+ mm_data = {"audio": [Audio.from_file(audio_path).audio_array]}
+ inputs = {"multi_modal_data": mm_data}
+
+ if req_data.prompt:
+ inputs["prompt"] = req_data.prompt
+ else:
+ inputs["prompt_token_ids"] = req_data.prompt_token_ids
+
+ inputs_list.append(inputs)
+
+ tic = time.perf_counter()
+ outputs = llm.generate(
+ inputs_list,
+ sampling_params=sampling_params,
+ )
+ latency = time.perf_counter() - tic
+
+ for o in outputs:
+ generated_text = o.outputs[0].text
+ print(generated_text)
+
+ try:
+ metrics = llm.get_metrics()
+ except AssertionError:
+ print("Metrics are not supported in the V0 engine.")
+ return None
+
+ total_num_output_tokens = sum(
+ len(output.outputs[0].token_ids) for output in outputs
+ )
+ num_drafts = 0
+ num_draft_tokens = 0
+ num_accepted_tokens = 0
+ num_spec_tokens = args.num_spec_tokens
+ acceptance_counts = [0] * num_spec_tokens
+
+ for metric in metrics:
+ if metric.name == "vllm:spec_decode_num_drafts":
+ assert isinstance(metric, Counter)
+ num_drafts += metric.value
+ elif metric.name == "vllm:spec_decode_num_draft_tokens":
+ assert isinstance(metric, Counter)
+ num_draft_tokens += metric.value
+ elif metric.name == "vllm:spec_decode_num_accepted_tokens":
+ assert isinstance(metric, Counter)
+ num_accepted_tokens += metric.value
+ elif metric.name == "vllm:spec_decode_num_accepted_tokens_per_pos":
+ assert isinstance(metric, Vector)
+ for pos in range(len(metric.values)):
+ acceptance_counts[pos] += metric.values[pos]
+
+ output_throughput = total_num_output_tokens / latency
+ acceptance_length = 1 + (num_accepted_tokens / num_drafts) if num_drafts > 0 else 1
+
+ # Calculate acceptance rate at each position
+ acceptance_rates = {}
+ for i in range(len(acceptance_counts)):
+ acceptance_rate = acceptance_counts[i] / num_drafts if num_drafts > 0 else 0
+ acceptance_rates[f"acceptance_rate_pos_{i}"] = round(acceptance_rate, 4)
+
+ # Prepare statistics dictionary
+ stats = {
+ "timestamp": datetime.now().isoformat(),
+ "num_spec_tokens": num_spec_tokens,
+ "total_num_output_tokens": total_num_output_tokens,
+ "latency_seconds": round(latency, 2),
+ "output_throughput_tokens_per_sec": round(output_throughput, 2),
+ "num_drafts": num_drafts,
+ "num_draft_tokens": num_draft_tokens,
+ "num_accepted_tokens": num_accepted_tokens,
+ "mean_acceptance_length": round(acceptance_length, 4),
+ **acceptance_rates,
+ }
+
+ # Print statistics
+ print("-" * 50)
+ print(f"total_num_output_tokens: {total_num_output_tokens}")
+ print(f"latency: {latency:.2f} s")
+ print(f"output_throughput: {output_throughput:.2f} tokens/s")
+ print(f"num_drafts: {num_drafts}")
+ print(f"num_draft_tokens: {num_draft_tokens}")
+ print(f"num_accepted_tokens: {num_accepted_tokens}")
+ print(f"mean acceptance length: {acceptance_length:.2f}")
+ print("-" * 50)
+
+ # Print acceptance rate at each token position
+ for i in range(len(acceptance_counts)):
+ acceptance_rate = acceptance_counts[i] / num_drafts if num_drafts > 0 else 0
+ print(f"acceptance at token {i}: {acceptance_rate:.2f}")
+
+ if args.output_file != "None":
+ os.makedirs(os.path.dirname(args.output_file), exist_ok=True)
+ with open(args.output_file, "w") as f:
+ json.dump(stats, f, indent=2)
+ print(f"Results saved to {args.output_file}")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/tools/vllm_offline_eagle3_vlm_batch.py b/tools/vllm_offline_eagle3_vlm_batch.py
index b9446e6b..32d1e0e0 100755
--- a/tools/vllm_offline_eagle3_vlm_batch.py
+++ b/tools/vllm_offline_eagle3_vlm_batch.py
@@ -37,7 +37,6 @@
from io import BytesIO
from datasets import load_dataset
-from PIL import Image
from vllm import LLM, SamplingParams
@@ -83,51 +82,17 @@ def parse_args():
return parser.parse_args()
-CAT_SHORT2LONG = {
- "acc": "Accounting",
- "agri": "Agriculture",
- "arch": "Architecture_and_Engineering",
- "art": "Art",
- "art_theory": "Art_Theory",
- "bas_med": "Basic_Medical_Science",
- "bio": "Biology",
- "chem": "Chemistry",
- "cli_med": "Clinical_Medicine",
- "cs": "Computer_Science",
- "design": "Design",
- "diag_med": "Diagnostics_and_Laboratory_Medicine",
- "econ": "Economics",
- "elec": "Electronics",
- "ep": "Energy_and_Power",
- "fin": "Finance",
- "geo": "Geography",
- "his": "History",
- "liter": "Literature",
- "manage": "Manage",
- "mark": "Marketing",
- "mate": "Materials",
- "math": "Math",
- "mech": "Mechanical_Engineering",
- "music": "Music",
- "phar": "Pharmacy",
- "phys": "Physics",
- "psy": "Psychology",
- "pub_health": "Public_Health",
- "socio": "Sociology",
-}
-
-
def main():
args = parse_args()
# Load dataset
print(f"Loading {args.dataset} dataset...")
if args.dataset == "MMMU/MMMU":
- ds = load_dataset("args.dataset", split="test", trust_remote_code=True)
+ ds = load_dataset(args.dataset, "History", split="test", trust_remote_code=True)
elif args.dataset == "Lin-Chen/MMStar":
ds = load_dataset(args.dataset, split="val", trust_remote_code=True)
- elif args.dataset == "hunyuan-ocr":
- ds = load_dataset("./dataset/hunyuan-ocr", split="test", trust_remote_code=True)
+ elif args.dataset == "opendatalab/OmniDocBench":
+ ds = load_dataset(args.dataset, split="train", trust_remote_code=True)
else:
ds = load_dataset(args.dataset, split="test", trust_remote_code=True)
if args.num_prompts is not None:
@@ -186,20 +151,16 @@ def main():
elif args.dataset == "HuggingFaceH4/MATH-500":
for item in ds:
prompts.append([{"role": "user", "content": item["problem"]}])
- elif args.dataset == "hunyuan-ocr":
+ elif args.dataset == "opendatalab/OmniDocBench":
for item in ds:
- image_url = pil_to_base64(
- Image.open(
- os.path.join("./dataset/hunyuan-ocr/images", item["img_path"])
- )
- )
+ image_url = pil_to_base64(item["image"])
prompts.append(
[
{
"role": "user",
"content": [
{"type": "image_url", "image_url": {"url": image_url}},
- {"type": "text", "text": item["question"]},
+ {"type": "text", "text": "提取并识别图片中的文本。"},
],
}
]
@@ -290,12 +251,10 @@ def main():
"answer": ds[i]["answer"],
}
)
- elif args.dataset == "hunyuan-ocr":
+ elif args.dataset == "opendatalab/OmniDocBench":
results_data.append(
{
- "question": ds[i]["question"],
"generated_text": generated_text,
- "answer": ds[i]["answer"],
}
)
@@ -303,11 +262,12 @@ def main():
len(output.outputs[0].token_ids) for output in outputs
)
+ output_throughput = total_num_output_tokens / total_time
metrics_info = {
"total_time": total_time,
"avg_time_per_sample": total_time / len(prompts) if prompts else 0,
"use_eagle": args.use_eagle,
- "output_throughput": total_num_output_tokens / total_time,
+ "output_throughput": output_throughput,
}
if args.use_eagle and speculative_config:
@@ -345,6 +305,7 @@ def main():
metrics_info["acceptance_rates"] = acceptance_rates
print(f"Mean acceptance length: {acceptance_length:.2f}")
+ print(f"output_throughput: {output_throughput:.2f} tokens/s")
print(f"acceptance rates: {acceptance_rates}")
except Exception as e:
print(f"Error getting metrics: {e}")