diff --git a/angelslim/compressor/speculative/train/configs/hunyuan_ocr-eagle3.json b/angelslim/compressor/speculative/train/configs/hunyuan_ocr-eagle3.json new file mode 100644 index 00000000..8af5b6b0 --- /dev/null +++ b/angelslim/compressor/speculative/train/configs/hunyuan_ocr-eagle3.json @@ -0,0 +1,34 @@ +{ + "architectures": [ + "Eagle3LlamaForCausalLM" + ], + "model_type": "llama", + "target_model_type": "hunyuan_vl", + "attention_bias": false, + "attention_dropout": 0.0, + "bos_token_id": 120000, + "eod_token_id": 120020, + "eos_token_id": 120020, + "dtype": "bfloat16", + "head_dim": 128, + "hidden_act": "silu", + "hidden_size": 1024, + "image_start_token_id": 120118, + "image_end_token_id": 120119, + "image_token_id": 120120, + "image_newline_token_id": 120121, + "initializer_range": 0.02, + "intermediate_size": 3584, + "max_position_embeddings": 32768, + "num_attention_heads": 16, + "num_hidden_layers": 1, + "num_key_value_heads": 8, + "rms_norm_eps": 1e-06, + "rope_theta": 10000.0, + "use_cache": true, + "vocab_size": 120818, + "tie_word_embeddings": true, + "transformers_version": "4.57.1", + "draft_vocab_size": 32000, + "modal_type": "VLM" +} diff --git a/angelslim/compressor/speculative/train/configs/qwen3-vl-2b-eagle3-mrope.json b/angelslim/compressor/speculative/train/configs/qwen3-vl-2b-eagle3-mrope.json new file mode 100644 index 00000000..f45e4b88 --- /dev/null +++ b/angelslim/compressor/speculative/train/configs/qwen3-vl-2b-eagle3-mrope.json @@ -0,0 +1,43 @@ +{ + "architectures": [ + "Eagle3LlamaForCausalLM" + ], + "model_type": "llama", + "target_model_type": "qwen3_vl", + "attention_bias": false, + "attention_dropout": 0.0, + "bos_token_id": 151643, + "dtype": "bfloat16", + "eos_token_id": 151645, + "head_dim": 128, + "hidden_act": "silu", + "hidden_size": 2048, + "initializer_range": 0.02, + "intermediate_size": 6144, + "max_position_embeddings": 262144, + "num_attention_heads": 16, + "num_hidden_layers": 1, + "num_key_value_heads": 8, + "rms_norm_eps": 1e-06, + "rope_scaling": { + "type": "default", + "rope_type": "default", + "mrope_interleaved": true, + "mrope_section": [ + 24, + 20, + 20 + ] + }, + "rope_theta": 5000000, + "use_cache": true, + "vocab_size": 151936, + "tie_word_embeddings": true, + "transformers_version": "4.57.1", + "image_token_id": 151655, + "video_token_id": 151656, + "vision_end_token_id": 151653, + "vision_start_token_id": 151652, + "draft_vocab_size": 32000, + "modal_type": "VLM" +} diff --git a/angelslim/compressor/speculative/train/configs/qwen3-vl-30b-a3b-eagle3-mrope.json b/angelslim/compressor/speculative/train/configs/qwen3-vl-30b-a3b-eagle3-mrope.json new file mode 100644 index 00000000..0ae87e8b --- /dev/null +++ b/angelslim/compressor/speculative/train/configs/qwen3-vl-30b-a3b-eagle3-mrope.json @@ -0,0 +1,43 @@ +{ + "architectures": [ + "Eagle3LlamaForCausalLM" + ], + "model_type": "llama", + "target_model_type": "qwen3_vl", + "attention_bias": false, + "attention_dropout": 0.0, + "bos_token_id": 151643, + "dtype": "bfloat16", + "eos_token_id": 151645, + "head_dim": 128, + "hidden_act": "silu", + "hidden_size": 2048, + "initializer_range": 0.02, + "intermediate_size": 6144, + "max_position_embeddings": 262144, + "num_attention_heads": 32, + "num_hidden_layers": 1, + "num_key_value_heads": 4, + "rms_norm_eps": 1e-06, + "rope_scaling": { + "type": "default", + "rope_type": "default", + "mrope_interleaved": true, + "mrope_section": [ + 24, + 20, + 20 + ] + }, + "rope_theta": 5000000, + "use_cache": true, + "vocab_size": 151936, + "tie_word_embeddings": false, + "transformers_version": "4.57.1", + "image_token_id": 151655, + "video_token_id": 151656, + "vision_end_token_id": 151653, + "vision_start_token_id": 151652, + "draft_vocab_size": 32000, + "modal_type": "VLM" +} diff --git a/angelslim/compressor/speculative/train/configs/qwen3-vl-4b-eagle3-mrope.json b/angelslim/compressor/speculative/train/configs/qwen3-vl-4b-eagle3-mrope.json index 64216957..c40b3c75 100644 --- a/angelslim/compressor/speculative/train/configs/qwen3-vl-4b-eagle3-mrope.json +++ b/angelslim/compressor/speculative/train/configs/qwen3-vl-4b-eagle3-mrope.json @@ -3,6 +3,7 @@ "Eagle3LlamaForCausalLM" ], "model_type": "llama", + "target_model_type": "qwen3_vl", "attention_bias": false, "attention_dropout": 0.0, "bos_token_id": 151643, diff --git a/angelslim/compressor/speculative/train/data/chat_templates.py b/angelslim/compressor/speculative/train/data/chat_templates.py index b476c998..3cfac781 100644 --- a/angelslim/compressor/speculative/train/data/chat_templates.py +++ b/angelslim/compressor/speculative/train/data/chat_templates.py @@ -26,18 +26,22 @@ class ChatTemplateType(Enum): """Supported chat template types.""" + QWEN2_AUDIO = "qwen2_audio" QWEN3 = "qwen3" HUNYUAN = "hunyuan" QWEN3_VL = "qwen3_vl" HUNYUAN_7B = "hunyuan_7b" + HUNYUAN_VL = "hunyuan_vl" # String to ChatTemplateType mapping CHAT_TEMPLATE_TYPE_MAPPING = { + "qwen2_audio": ChatTemplateType.QWEN2_AUDIO, "qwen3": ChatTemplateType.QWEN3, "hunyuan": ChatTemplateType.HUNYUAN, "hunyuan_7b": ChatTemplateType.HUNYUAN_7B, "qwen3_vl": ChatTemplateType.QWEN3_VL, + "hunyuan_vl": ChatTemplateType.HUNYUAN_VL, } @@ -133,6 +137,26 @@ def _initialize_templates(self) -> Dict[ChatTemplateType, ChatTemplate]: } ], ), + ChatTemplateType.QWEN2_AUDIO: ChatTemplate( + user_header="<|im_start|>user\n", + assistant_header="<|im_start|>assistant\n", + system_prompt=[ + { + "type": "text", + "text": ("You are a helpful assistant."), + } + ], + ), + ChatTemplateType.HUNYUAN_VL: ChatTemplate( + user_header="<|hy_Assistant|>", + assistant_header="<|hy_User|>", + system_prompt=[ + { + "type": "text", + "text": "", + } + ], + ), } def get_template(self, chat_template_type: ChatTemplateType) -> ChatTemplate: diff --git a/angelslim/compressor/speculative/train/data/data_utils.py b/angelslim/compressor/speculative/train/data/data_utils.py index e94378f1..799e0933 100644 --- a/angelslim/compressor/speculative/train/data/data_utils.py +++ b/angelslim/compressor/speculative/train/data/data_utils.py @@ -130,6 +130,20 @@ def paddingtensor3D_CBN(tensor_list): return torch.cat(out_tensor_list, dim=1) +def paddingtensor3D_BCN(tensor_list): + if all(tensor is None for tensor in tensor_list): + return None + N = max(tensor.shape[-1] for tensor in tensor_list if tensor is not None) + out_tensor_list = [] + for tensor in tensor_list: + b, c, n = tensor.shape + outtensor = torch.zeros(b, c, N, dtype=tensor_list[0].dtype) + if tensor is not None: + outtensor[:, :, :n] = tensor + out_tensor_list.append(outtensor) + return torch.cat(out_tensor_list, dim=0) + + def paddingtensor3D_BHW(tensor_list): if all(tensor is None for tensor in tensor_list): return None @@ -240,11 +254,90 @@ def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]: batch["target_hiddens"] = torch.cat( [paddingtensor(item["target_hiddens"], max_length) for item in features] ) + if all( + "inputs_embeds" in item and item["inputs_embeds"] is not None + for item in features + ): batch["inputs_embeds"] = torch.cat( [paddingtensor(item["inputs_embeds"], max_length) for item in features] ) + if all( + "position_ids" in item and item["position_ids"] is not None + for item in features + ): batch["position_ids"] = paddingtensor3D_CBN( [item["position_ids"] for item in features] ) return batch + + +class VLMHunyuanDataCollatorWithPadding: + + def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]: + max_length = max(item["input_ids"].shape[1] for item in features) + batch_input_ids = torch.cat( + [paddingtensor2D(item["input_ids"], max_length) for item in features] + ) + batch_attention_mask = torch.cat( + [paddingtensor2D(item["attention_mask"], max_length) for item in features] + ) + batch_loss_mask = torch.cat( + [paddingtensor2D(item["loss_mask"], max_length) for item in features] + ) + batch = { + "input_ids": batch_input_ids, + "attention_mask": batch_attention_mask, + "loss_mask": batch_loss_mask, + "hidden_states": None, + "target_hiddens": None, + "inputs_embeds": None, + "position_ids": None, + "input_position_ids": None, + } + + if "pixel_values" in features[0]: + batch["pixel_values"] = paddingtensor3D_BHW( + [item["pixel_values"] for item in features] + ) + + if all( + "image_grid_thw" in item and item["image_grid_thw"] is not None + for item in features + ): + batch["image_grid_thw"] = torch.cat( + [item["image_grid_thw"] for item in features], dim=0 + ) + + # Check if both hidden_states and target_hiddens exist in all features + if all( + "hidden_states" in item and "target_hiddens" in item for item in features + ): + batch["hidden_states"] = torch.cat( + [paddingtensor(item["hidden_states"], max_length) for item in features] + ) + batch["target_hiddens"] = torch.cat( + [paddingtensor(item["target_hiddens"], max_length) for item in features] + ) + if all( + "inputs_embeds" in item and item["inputs_embeds"] is not None + for item in features + ): + batch["inputs_embeds"] = torch.cat( + [paddingtensor(item["inputs_embeds"], max_length) for item in features] + ) + if all( + "input_position_ids" in item and item["input_position_ids"] is not None + for item in features + ): + batch["input_position_ids"] = paddingtensor3D_BCN( + [item["input_position_ids"] for item in features] + ) + if all( + "position_ids" in item and item["position_ids"] is not None + for item in features + ): + batch["position_ids"] = torch.cat( + [paddingtensor2D(item["position_ids"], max_length) for item in features] + ) + return batch diff --git a/angelslim/compressor/speculative/train/data/dataset.py b/angelslim/compressor/speculative/train/data/dataset.py index cfcb1b51..cd775fd6 100644 --- a/angelslim/compressor/speculative/train/data/dataset.py +++ b/angelslim/compressor/speculative/train/data/dataset.py @@ -40,6 +40,7 @@ def __init__( chat_template_type: Optional[Union[str, ChatTemplateType]] = None, display: bool = False, cache_in_memory: bool = False, + target_model_type: Optional[str] = None, ): """ Initialize DatasetManager with DataArguments. @@ -60,6 +61,7 @@ def __init__( self.model_max_length = model_max_length self.display = display self.cache_in_memory = cache_in_memory + self.target_model_type = target_model_type # Convert chat_template_type to ChatTemplateType enum if chat_template_type is None: @@ -76,6 +78,7 @@ def __init__( self.online_dataset_builder = DatasetBuilderFactory.create( training_mode="online", modal_type=data_args.modal_type, + target_model_type=self.target_model_type, tokenizer=tokenizer, max_length=model_max_length, shuffle_seed=data_args.shuffle_seed, @@ -86,6 +89,7 @@ def __init__( self.offline_dataset_builder = DatasetBuilderFactory.create( training_mode="offline", modal_type=data_args.modal_type, + target_model_type=self.target_model_type, cache_in_memory=cache_in_memory, ) diff --git a/angelslim/compressor/speculative/train/data/dataset_builder/__init__.py b/angelslim/compressor/speculative/train/data/dataset_builder/__init__.py index 79cef188..dd208cfe 100644 --- a/angelslim/compressor/speculative/train/data/dataset_builder/__init__.py +++ b/angelslim/compressor/speculative/train/data/dataset_builder/__init__.py @@ -13,13 +13,23 @@ # limitations under the License. from .dataset_builder_factory import DatasetBuilderFactory -from .offline_dataset_builder import OfflineLLMDatasetBuilder, OfflineVLMDatasetBuilder -from .online_dataset_builder import OnlineLLMDatasetBuilder, OnlineVLMDatasetBuilder +from .offline_dataset_builder import ( + OfflineLLMDatasetBuilder, + OfflineVLMDatasetBuilder, + OfflineVLMHunyuanVLDatasetBuilder, +) +from .online_dataset_builder import ( + OnlineLLMDatasetBuilder, + OnlineVLMDatasetBuilder, + OnlineVLMHunyuanVLDatasetBuilder, +) __all__ = [ "OnlineLLMDatasetBuilder", "OnlineVLMDatasetBuilder", + "OnlineVLMHunyuanVLDatasetBuilder", "OfflineLLMDatasetBuilder", "OfflineVLMDatasetBuilder", + "OfflineVLMHunyuanVLDatasetBuilder", "DatasetBuilderFactory", ] diff --git a/angelslim/compressor/speculative/train/data/dataset_builder/dataset_builder_factory.py b/angelslim/compressor/speculative/train/data/dataset_builder/dataset_builder_factory.py index 7133298d..472ff537 100644 --- a/angelslim/compressor/speculative/train/data/dataset_builder/dataset_builder_factory.py +++ b/angelslim/compressor/speculative/train/data/dataset_builder/dataset_builder_factory.py @@ -26,34 +26,44 @@ class DatasetBuilderFactory: @classmethod def register( - cls, training_mode: str = "online", modal_type: str = "LLM" + cls, + training_mode: str = "online", + modal_type: str = "LLM", + target_model_type: str = None, ) -> Callable[[Type[DatasetBuilder]], Type[DatasetBuilder]]: """Decorator to register dataset builders.""" def decorator(builder_cls: Type[DatasetBuilder]) -> Type[DatasetBuilder]: - if (training_mode, modal_type) in cls._builders: + if (training_mode, modal_type, target_model_type) in cls._builders: print( f"DatasetBuilder for training_mode '{training_mode}'" f" modal_type '{modal_type}' already exists." + f" target_model_type '{target_model_type}' already exists." ) - cls._builders[(training_mode, modal_type)] = builder_cls + cls._builders[(training_mode, modal_type, target_model_type)] = builder_cls return builder_cls return decorator @classmethod def create( - cls, training_mode: str = "online", modal_type: str = "LLM", **kwargs: Any + cls, + training_mode: str = "online", + modal_type: str = "LLM", + target_model_type: str = None, + **kwargs: Any, ) -> DatasetBuilder: """Create a dataset builder instance based on training_mode and modal_type.""" - if (training_mode, modal_type) not in cls._builders: + if (training_mode, modal_type, target_model_type) not in cls._builders: available = list(cls._builders.keys()) raise ValueError( - f"Unknown training_mode '{training_mode}'" - f" modal_type '{modal_type}'. Available: {available}" + f"Unknown training_mode '{training_mode}' " + f"modal_type '{modal_type}' " + f"target_model_type '{target_model_type}'. " + f"Available: {available}" ) - builder_class = cls._builders[(training_mode, modal_type)] + builder_class = cls._builders[(training_mode, modal_type, target_model_type)] return builder_class(**kwargs) @classmethod diff --git a/angelslim/compressor/speculative/train/data/dataset_builder/offline_dataset_builder.py b/angelslim/compressor/speculative/train/data/dataset_builder/offline_dataset_builder.py index f45a8a6a..4a84c766 100644 --- a/angelslim/compressor/speculative/train/data/dataset_builder/offline_dataset_builder.py +++ b/angelslim/compressor/speculative/train/data/dataset_builder/offline_dataset_builder.py @@ -21,7 +21,11 @@ from angelslim.utils import rank0_print -from ..data_utils import DataCollatorWithPadding, VLMDataCollatorWithPadding +from ..data_utils import ( + DataCollatorWithPadding, + VLMDataCollatorWithPadding, + VLMHunyuanDataCollatorWithPadding, +) from .base_dataset_builder import DatasetBuilder from .dataset_builder_factory import DatasetBuilderFactory @@ -230,7 +234,7 @@ def _load_ckpt(self, idx: int) -> Optional[Dict[str, torch.Tensor]]: "target_hiddens", # B, N, D "hidden_states", # B, N, 3*D "loss_mask", # B, N - "inputs_embeds", # B, N, D + # "inputs_embeds", # B, N, D "position_ids", # 3, B, N ] missing_keys = [key for key in required_keys if key not in data] @@ -282,7 +286,7 @@ def get_data_collator(self) -> Any: return DataCollatorWithPadding() -@DatasetBuilderFactory.register("offline", "VLM") +@DatasetBuilderFactory.register("offline", "VLM", "qwen3_vl") class OfflineVLMDatasetBuilder(DatasetBuilder): def __init__( self, file_pattern: str = "*.ckpt", cache_in_memory: bool = False, **kwargs: Any @@ -302,3 +306,25 @@ def build_dataset(self, datapath: str, **kwargs: Any) -> Dataset: def get_data_collator(self) -> Any: return VLMDataCollatorWithPadding() + + +@DatasetBuilderFactory.register("offline", "VLM", "hunyuan_vl") +class OfflineVLMHunyuanVLDatasetBuilder(DatasetBuilder): + def __init__( + self, file_pattern: str = "*.ckpt", cache_in_memory: bool = False, **kwargs: Any + ): + self.file_pattern = file_pattern + self.cache_in_memory = cache_in_memory + + def build_dataset(self, datapath: str, **kwargs: Any) -> Dataset: + """ + Create offline datasets from pre-computed .ckpt files. + """ + return OfflineVLMEagle3Dataset( + data_dir=datapath, + file_pattern=self.file_pattern, + cache_in_memory=self.cache_in_memory, + ) + + def get_data_collator(self) -> Any: + return VLMHunyuanDataCollatorWithPadding() diff --git a/angelslim/compressor/speculative/train/data/dataset_builder/online_dataset_builder.py b/angelslim/compressor/speculative/train/data/dataset_builder/online_dataset_builder.py index f8f9417c..d25c8711 100644 --- a/angelslim/compressor/speculative/train/data/dataset_builder/online_dataset_builder.py +++ b/angelslim/compressor/speculative/train/data/dataset_builder/online_dataset_builder.py @@ -16,13 +16,18 @@ import torch from datasets import Features, Value, load_dataset +from PIL import Image from torch.utils.data import Dataset from transformers import AutoProcessor, AutoTokenizer from angelslim.utils import rank0_print from ..chat_templates import ChatTemplateType -from ..data_utils import DataCollatorWithPadding, VLMDataCollatorWithPadding +from ..data_utils import ( + DataCollatorWithPadding, + VLMDataCollatorWithPadding, + VLMHunyuanDataCollatorWithPadding, +) from .base_dataset_builder import OnlineDatasetBuilder from .dataset_builder_factory import DatasetBuilderFactory @@ -50,7 +55,7 @@ def get_data_collator(self) -> Any: return DataCollatorWithPadding() -@DatasetBuilderFactory.register("online", "VLM") +@DatasetBuilderFactory.register("online", "VLM", "qwen3_vl") class OnlineVLMDatasetBuilder(OnlineDatasetBuilder): def __init__( self, @@ -284,3 +289,253 @@ def _process_single_conversation( except Exception as e: rank0_print(f"Error processing conversation: {e}") return None + + +@DatasetBuilderFactory.register("online", "VLM", "hunyuan_vl") +class OnlineVLMHunyuanVLDatasetBuilder(OnlineDatasetBuilder): + def __init__( + self, + tokenizer: Union[AutoTokenizer, AutoProcessor], + max_length: int = 2048, + shuffle_seed: int = 42, + chat_template_type: ChatTemplateType = ChatTemplateType.QWEN3, + display: bool = False, + **kwargs: Any, + ): + super().__init__( + tokenizer, + max_length, + shuffle_seed, + chat_template_type, + display, + ) + + def build_dataset( + self, + datapath: str, + num_proc: int = 8, + shuffle: bool = True, + sample_num: Optional[int] = None, + ) -> Dataset: + try: + # Load dataset + features = Features( + { + "id": Value("string"), + "conversations": [ + { + "role": Value("string"), + "content": [ + { + "type": Value("string"), + "text": Value("string"), + "image": Value("string"), + } + ], + } + ], + } + ) + ds = load_dataset("json", data_files=datapath, features=features) + + # Conditionally shuffle dataset + if shuffle: + ds = ds["train"].shuffle(seed=self.shuffle_seed) + else: + ds = ds["train"] + if sample_num is not None and 0 < sample_num < len(ds): + ds = ds.select(range(sample_num)) + + # Store original columns for removal + original_columns = ds.column_names + + # Apply preprocessing + processed_ds = ds.map( + self._preprocess_function, + batched=True, + num_proc=num_proc, + remove_columns=original_columns, + load_from_cache_file=False, + desc="Processing conversations", + ) + + # Filter out None results with multiprocessing support + processed_ds = processed_ds.filter( + lambda batch: [ids is not None for ids in batch["input_ids"]], + batched=True, + num_proc=num_proc, + desc="Filtering empty input_ids", + ) + processed_ds.set_format(type="torch") + + return processed_ds + + except Exception as e: + raise RuntimeError(f"Dataset building failed for {datapath}") from e + + def get_data_collator(self) -> Any: + return VLMHunyuanDataCollatorWithPadding() + + def _preprocess_function(self, examples: Dict[str, List]) -> Dict[str, List]: + new_examples = { + "input_ids": [], + "attention_mask": [], + "loss_mask": [], + "pixel_values": [], + "image_grid_thw": [], + "position_ids": [], + "input_position_ids": [], + } + for i in range(len(examples["id"])): + try: + processed_example = self._process_single_conversation( + examples["conversations"][i] + ) + if processed_example is not None: + for key in new_examples.keys(): + if key not in processed_example: + new_examples[key].append(None) + else: + new_examples[key].append(processed_example[key]) + + except Exception as e: + rank0_print(f"Error processing example: {e}") + # Add None placeholders to maintain batch consistency + for key in new_examples: + new_examples[key].append(None) + cleaned_new_examples = {} + for key, value in new_examples.items(): + if any(v is not None for v in value): + cleaned_new_examples[key] = value + return new_examples + + def _visualize_loss_mask( + self, input_ids: torch.Tensor, loss_mask: torch.Tensor, conversation: str + ) -> None: + """ + Visualize loss_mask with color-coded output. + + Args: + input_ids: Token IDs + loss_mask: Loss mask tensor (1 for training, 0 for ignoring) + conversation: Original conversation text + """ + input_ids = input_ids.view(-1) + return super()._visualize_loss_mask(input_ids, loss_mask, conversation) + + def _create_loss_mask_from_offsets( + self, conversation: str, offsets: torch.Tensor + ) -> torch.Tensor: + if offsets.ndim == 3: + offsets = offsets[0] + return super()._create_loss_mask_from_offsets(conversation, offsets) + + def _process_single_conversation( + self, conversation_data: List[Dict] + ) -> Optional[Dict]: + if not conversation_data or not isinstance(conversation_data, list): + return None + + try: + for message in conversation_data: + # adapt to hunyuan_vl + if message["role"] == "assistant" or message["role"] == "system": + message["content"] = message["content"][0]["text"] + else: + for content in message["content"]: + if "image" in content and content["image"] is None: + content.pop("image") + if "text" in content and content["text"] is None: + content.pop("text") + + # Build messages with system prompt + messages = self._build_messages(conversation_data) + if not messages: + return None + + text = self.tokenizer.apply_chat_template( + messages, tokenize=False, add_generation_prompt=False + ) + image_inputs, _ = self._extract_vision_info(messages) + + encoding = self.tokenizer( + text=[text], + images=image_inputs, + return_tensors="pt", + return_offsets_mapping=True, + max_length=self.max_length, + truncation=True, + padding=False, + ) + input_ids = encoding["input_ids"] + offsets = encoding["offset_mapping"] + input_position_ids = encoding["position_ids"] + conversation = self.tokenizer.decode( + input_ids[0], skip_special_tokens=False + ) + + # Create loss mask for assistant responses + try: + # loss_mask = torch.tensor(conversation_data['loss_mask']).unsqueeze(0) + loss_mask = self._create_loss_mask_from_offsets(conversation, offsets) + except Exception as e: + rank0_print(f"Error creating loss mask: {e}") + rank0_print(f"offsets: {offsets}") + raise e + attention_mask = torch.ones_like(input_ids) + + # Visualize loss mask if display mode is enabled + if self.display and self.display_count == 0: + try: + self._visualize_loss_mask(input_ids, loss_mask, conversation) + except Exception as e: + rank0_print(f"Error visualizing loss mask: {e}") + rank0_print(f"input_ids: {input_ids}, loss_mask: {loss_mask}") + raise e + self.display_count += 1 + + result_dict = { + "input_ids": input_ids.view(1, -1), + "attention_mask": attention_mask.view(1, -1), + "loss_mask": loss_mask.view(1, -1), + "input_position_ids": input_position_ids, + } + + if "pixel_values" in encoding: + result_dict["pixel_values"] = encoding["pixel_values"].unsqueeze(0) + if "image_grid_thw" in encoding: + result_dict["image_grid_thw"] = encoding["image_grid_thw"] + + return result_dict + + except Exception as e: + rank0_print(f"Error processing conversation: {e}") + return None + + def _extract_vision_info(self, messages: List[Dict]) -> tuple: + """Extract image and video paths from messages""" + image_paths = [] + video_paths = [] + + for message in messages: + content = message.get("content", []) + if not isinstance(content, list): + continue + + for item in content: + if item.get("type") == "image": + # Handle both file paths and PIL images + if isinstance(item["image"], str): + try: + img = Image.open(item["image"]) + image_paths.append(img) + except ValueError as e: + raise ValueError( + f"Could not open image file: {item['image']}, {e}" + ) + elif isinstance(item["image"], Image.Image): + image_paths.append(item["image"]) + elif item.get("type") == "video": + video_paths.append(item["video"]) + + return image_paths, video_paths diff --git a/angelslim/compressor/speculative/train/models/draft/base_model.py b/angelslim/compressor/speculative/train/models/draft/base_model.py index a6549d90..e07eee75 100644 --- a/angelslim/compressor/speculative/train/models/draft/base_model.py +++ b/angelslim/compressor/speculative/train/models/draft/base_model.py @@ -113,19 +113,19 @@ def _load_from_safetensors( ): """Load embedding weights from safetensors format.""" try: - index_file = os.path.join(model_path, "model.safetensors.index.json") - if not os.path.exists(index_file): - return None - - with open(index_file, "r") as f: - index_json = json.load(f) - - if embed_weight_key in index_json["weight_map"]: - emb_path = index_json["weight_map"][embed_weight_key] - else: - raise KeyError("Embedding weights key not found in index.") - - safetensors_file = os.path.join(model_path, emb_path) + try: + index_file = os.path.join(model_path, "model.safetensors.index.json") + if not os.path.exists(index_file): + raise KeyError("no model.safetensors.index.json !") + with open(index_file, "r") as f: + index_json = json.load(f) + if embed_weight_key in index_json["weight_map"]: + emb_path = index_json["weight_map"][embed_weight_key] + else: + raise KeyError("Embedding weights key not found in index.") + safetensors_file = os.path.join(model_path, emb_path) + except Exception: + safetensors_file = os.path.join(model_path, "model.safetensors") with safe_open(safetensors_file, framework="pt", device="cpu") as f: tensor_slice = f.get_slice(embed_weight_key) @@ -142,19 +142,22 @@ def _load_from_pytorch_bin( ): """Load embedding weights from pytorch_model.bin format.""" try: - index_file = os.path.join(model_path, "pytorch_model.bin.index.json") - if not os.path.exists(index_file): - return None - - with open(index_file, "r") as f: - index_json = json.load(f) - - if embed_weight_key in index_json["weight_map"]: - emb_path = index_json["weight_map"][embed_weight_key] - else: - raise KeyError("Embedding weights key not found in index.") - - bin_file = os.path.join(model_path, emb_path) + try: + index_file = os.path.join(model_path, "pytorch_model.bin.index.json") + if not os.path.exists(index_file): + raise KeyError("no pytorch_model.bin.index.json !") + with open(index_file, "r") as f: + index_json = json.load(f) + + if embed_weight_key in index_json["weight_map"]: + emb_path = index_json["weight_map"][embed_weight_key] + else: + raise KeyError("Embedding weights key not found in index.") + + bin_file = os.path.join(model_path, emb_path) + + except Exception: + bin_file = os.path.join(model_path, "pytorch_model.bin") weights = torch.load(bin_file, map_location="cpu") tensor = weights[embed_weight_key].float() diff --git a/angelslim/compressor/speculative/train/models/target/target_head.py b/angelslim/compressor/speculative/train/models/target/target_head.py index f073c99a..5296ed79 100644 --- a/angelslim/compressor/speculative/train/models/target/target_head.py +++ b/angelslim/compressor/speculative/train/models/target/target_head.py @@ -96,8 +96,12 @@ def from_pretrained( ) # Get model dimensions - hidden_size = config.hidden_size - vocab_size = config.vocab_size + if config.model_type in ["qwen3_vl", "qwen3_vl_moe"]: + hidden_size = config.text_config.hidden_size + vocab_size = config.text_config.vocab_size + else: + hidden_size = config.hidden_size + vocab_size = config.vocab_size # Initialize lm_head lm_head = nn.Linear(hidden_size, vocab_size, bias=False) @@ -105,24 +109,29 @@ def from_pretrained( # Load lm_head weights from safetensors try: # Read safetensors index to locate lm_head weights - index_path = os.path.join( - model_name_or_path, "model.safetensors.index.json" - ) - - if not os.path.exists(index_path): - raise FileNotFoundError( - f"model.safetensors.index.json not found in {model_name_or_path}. " - "Please ensure the model is saved in safetensors " - "format with sharding." + try: + index_path = os.path.join( + model_name_or_path, "model.safetensors.index.json" ) - # Model is sharded, use index to find lm_head - with open(index_path, "r") as f: - index_json = json.loads(f.read()) - head_path = index_json["weight_map"][lm_head_key] + if not os.path.exists(index_path): + raise FileNotFoundError( + "model.safetensors.index.json" + f"not found in {model_name_or_path}. " + "Please ensure the model is saved in safetensors " + "format with sharding." + ) + + # Model is sharded, use index to find lm_head + with open(index_path, "r") as f: + index_json = json.loads(f.read()) + head_path = index_json["weight_map"][lm_head_key] + + # Load lm_head weights using safetensors + safetensors_file = os.path.join(model_name_or_path, head_path) + except Exception: + safetensors_file = os.path.join(model_name_or_path, "model.safetensors") - # Load lm_head weights using safetensors - safetensors_file = os.path.join(model_name_or_path, head_path) with safe_open(safetensors_file, framework="pt", device="cpu") as f: tensor_slice = f.get_slice(lm_head_key) _, hidden_dim = tensor_slice.get_shape() diff --git a/angelslim/compressor/speculative/train/models/target/target_model_wrapper.py b/angelslim/compressor/speculative/train/models/target/target_model_wrapper.py index 3cbc675c..a7bf2d43 100644 --- a/angelslim/compressor/speculative/train/models/target/target_model_wrapper.py +++ b/angelslim/compressor/speculative/train/models/target/target_model_wrapper.py @@ -132,7 +132,16 @@ def _extract_auxiliary_hidden_states( Concatenated hidden states, shape [batch_size, seq_len, hidden_size * 3] """ if aux_layer_ids is None: - aux_layer_ids = self._get_default_aux_layer_ids(len(hidden_states)) + if hasattr(self.model.config, "num_hidden_layers"): + num_layers = self.model.config.num_hidden_layers + elif hasattr(self.model.config.text_config, "num_hidden_layers"): + num_layers = self.model.config.text_config.num_hidden_layers + else: + raise ValueError( + "Failed to set aux hidden states layers as model config. " + f"{self.model.config} does not have num_hidden_layers" + ) + aux_layer_ids = self._get_default_aux_layer_ids(num_layers) # Offset by 1 to skip embedding layer embed_offset = 1 @@ -278,26 +287,47 @@ class VLMTransformersBackend(BaseBackend): """VLM HuggingFace Transformers backend""" def load_model(self): - from transformers import AutoModelForImageTextToText, AutoProcessor + if self.target_model_type == "hunyuan_vl": + from transformers import AutoProcessor, HunYuanVLForConditionalGeneration - device = decide_device_for_distributed() - print_with_rank(f"Loading model to device: {device}") + device = decide_device_for_distributed() + print_with_rank(f"Loading model to device: {device}") - # Prepare model loading configuration - model_kwargs = self._prepare_model_kwargs(device) + # Prepare model loading configuration + model_kwargs = self._prepare_model_kwargs(device) - self.model = AutoModelForImageTextToText.from_pretrained( - self.model_path, **model_kwargs - ) + self.model = HunYuanVLForConditionalGeneration.from_pretrained( + self.model_path, **model_kwargs + ) + self.model.eval() - # Freeze the base model - for param in self.model.parameters(): - param.requires_grad = False - self.model.eval() + # Load processor + self.tokenizer = AutoProcessor.from_pretrained( + self.model_path, trust_remote_code=True + ) + else: + from transformers import AutoModelForImageTextToText, AutoProcessor - self.tokenizer = AutoProcessor.from_pretrained( - self.model_path, trust_remote_code=True - ) + device = decide_device_for_distributed() + print_with_rank(f"Loading model to device: {device}") + + # Prepare model loading configuration + model_kwargs = self._prepare_model_kwargs(device) + + self.model = AutoModelForImageTextToText.from_pretrained( + self.model_path, **model_kwargs + ) + + # Freeze the base model + for param in self.model.parameters(): + param.requires_grad = False + self.model.eval() + + # Load processor + self.tokenizer = AutoProcessor.from_pretrained( + self.model_path, + trust_remote_code=True, + ) def _prepare_model_kwargs(self, device: str) -> dict: """ @@ -345,16 +375,24 @@ def hook(module, args, kwargs): position_ids_list.append(kwargs["position_ids"].clone().detach().cpu()) return args, kwargs - handle = self.model.language_model.register_forward_pre_hook( - hook, with_kwargs=True - ) - + if self.target_model_type == "qwen3_vl": + handle = self.model.language_model.register_forward_pre_hook( + hook, with_kwargs=True + ) + elif self.target_model_type == "hunyuan_vl": + handle = self.model.model.register_forward_pre_hook(hook, with_kwargs=True) + else: + raise ValueError(f"Unsupported target model type: {self.target_model_type}") pixel_values = kwargs.get("pixel_values", None) + if pixel_values is not None: + pixel_values = pixel_values.squeeze(0) image_grid_thw = kwargs.get("image_grid_thw", None) + input_position_ids = kwargs.get("input_position_ids", None) with torch.no_grad(): outputs = self.model( input_ids, attention_mask=attention_mask, + position_ids=input_position_ids, pixel_values=pixel_values, image_grid_thw=image_grid_thw, output_hidden_states=True, @@ -362,8 +400,20 @@ def hook(module, args, kwargs): ) handle.remove() - inputs_embeds = inputs_embeds_list[0].to(input_ids.device) - position_ids = position_ids_list[0].to(input_ids.device) + inputs_embeds = ( + inputs_embeds_list[0].to(input_ids.device) if inputs_embeds_list else None + ) + + if self.target_model_type == "hunyuan_vl": + position_ids = ( + position_ids_list[0][:, 0, :].to(input_ids.device) + if position_ids_list + else None + ) + else: + position_ids = ( + position_ids_list[0].to(input_ids.device) if position_ids_list else None + ) # Extract auxiliary hidden states aux_layer_ids = kwargs.get("aux_hidden_states_layer_ids", None) @@ -407,16 +457,25 @@ def hook(module, args, kwargs): position_ids_list.append(kwargs["position_ids"].clone().detach().cpu()) return args, kwargs - handle = self.model.language_model.register_forward_pre_hook( - hook, with_kwargs=True - ) + if self.target_model_type == "qwen3_vl": + handle = self.model.language_model.register_forward_pre_hook( + hook, with_kwargs=True + ) + elif self.target_model_type == "hunyuan_vl": + handle = self.model.model.register_forward_pre_hook(hook, with_kwargs=True) + else: + raise ValueError(f"Unsupported target model type: {self.target_model_type}") pixel_values = kwargs.get("pixel_values", None) + if pixel_values is not None: + pixel_values = pixel_values.squeeze(0) image_grid_thw = kwargs.get("image_grid_thw", None) + input_position_ids = kwargs.get("input_position_ids", None) with torch.no_grad(): outputs = self.model( input_ids, pixel_values=pixel_values, + position_ids=input_position_ids, image_grid_thw=image_grid_thw, attention_mask=attention_mask, output_hidden_states=True, @@ -424,9 +483,19 @@ def hook(module, args, kwargs): ) handle.remove() - inputs_embeds = inputs_embeds_list[0].to(input_ids.device) - position_ids = position_ids_list[0].to(input_ids.device) - + inputs_embeds = ( + inputs_embeds_list[0].to(input_ids.device) if inputs_embeds_list else None + ) + if self.target_model_type == "hunyuan_vl": + position_ids = ( + position_ids_list[0][:, 0, :].to(input_ids.device) + if position_ids_list + else None + ) + else: + position_ids = ( + position_ids_list[0].to(input_ids.device) if position_ids_list else None + ) # Extract auxiliary hidden states aux_layer_ids = kwargs.get("aux_hidden_states_layer_ids", None) aux_hidden_states = self._extract_auxiliary_hidden_states( @@ -478,7 +547,12 @@ class TargetModelWrapper: } def __init__( - self, model_path: str, modal_type: str = "LLM", backend: str = "hf", **kwargs + self, + model_path: str, + modal_type: str = "LLM", + backend: str = "hf", + target_model_type: str = None, + **kwargs, ): """ Initialize TargetModel with specified backend @@ -496,6 +570,7 @@ def __init__( self.backend_name = backend self.backend = self.BACKENDS[(backend, modal_type)](model_path, **kwargs) + self.backend.target_model_type = target_model_type self.backend.load_model() def get_hidden_states_and_logits( @@ -583,6 +658,7 @@ def create_target_model( model_path: str, torch_dtype: torch.dtype = torch.bfloat16, trust_remote_code: bool = True, + target_model_type: str = None, **extra_kwargs, ) -> TargetModelWrapper: """ @@ -626,4 +702,9 @@ def create_target_model( f"Use one of: {list(TargetModelWrapper.BACKENDS.keys())}" ) - return TargetModelWrapper(backend=backend, model_path=model_path, **kwargs) + return TargetModelWrapper( + backend=backend, + model_path=model_path, + target_model_type=target_model_type, + **kwargs, + ) diff --git a/angelslim/compressor/speculative/train/trainer/offline_eagle3_trainer.py b/angelslim/compressor/speculative/train/trainer/offline_eagle3_trainer.py index ad4235f2..f0db3dc2 100644 --- a/angelslim/compressor/speculative/train/trainer/offline_eagle3_trainer.py +++ b/angelslim/compressor/speculative/train/trainer/offline_eagle3_trainer.py @@ -136,7 +136,6 @@ def prepare_data_for_draft_model( - hidden_states: Pre-computed aux hidden states from target model - attention_mask: Attention mask - loss_mask: Mask for loss computation - - inputs_embeds: Input embeddings (text and visual) - position_ids (optional): Position IDs (3D for VLMs mrope) Returns: @@ -148,7 +147,6 @@ def prepare_data_for_draft_model( "hidden_states", "attention_mask", "loss_mask", - "inputs_embeds", "position_ids", ] output_fields = [ @@ -157,14 +155,14 @@ def prepare_data_for_draft_model( "hidden_states", "attention_mask", "loss_mask", - "inputs_embeds", "position_ids", ] - target_logits = self.target_head(inputs["target_hiddens"]) + target_logits = self.target_head( + inputs["target_hiddens"].to(self.target_head.lm_head.weight.dtype) + ) loss_mask = inputs["loss_mask"] input_ids = inputs["input_ids"] - # inputs_embeds = inputs.get("inputs_embeds", None) position_ids = inputs.get("position_ids", None) # Apply right padding and move tensors to correct device @@ -176,7 +174,6 @@ def prepare_data_for_draft_model( outputs["target_logits"] = target_logits outputs["loss_mask"] = loss_mask outputs["input_ids"] = input_ids - # outputs["inputs_embeds"] = inputs_embeds outputs["position_ids"] = position_ids return outputs diff --git a/angelslim/compressor/speculative/train/trainer/online_eagle3_trainer.py b/angelslim/compressor/speculative/train/trainer/online_eagle3_trainer.py index 4185a285..74af8c8c 100644 --- a/angelslim/compressor/speculative/train/trainer/online_eagle3_trainer.py +++ b/angelslim/compressor/speculative/train/trainer/online_eagle3_trainer.py @@ -119,12 +119,18 @@ def prepare_data_for_draft_model(self, inputs): attention_mask = inputs["attention_mask"] loss_mask = inputs["loss_mask"] + kwargs = { + k: v + for k, v in inputs.items() + if k not in ["input_ids", "attention_mask", "loss_mask"] + } # Get hidden states and logits from target model hidden_states, target_logits, _, position_ids = ( self.target_model.get_hidden_states_and_logits( input_ids=input_ids, attention_mask=attention_mask, aux_hidden_states_layer_ids=self._aux_hidden_states_layer_ids, + **kwargs, ) ) diff --git a/scripts/speculative/hunyuan_ocr/generate_vlm_hidden_for_draft_model.sh b/scripts/speculative/hunyuan_ocr/generate_vlm_hidden_for_draft_model.sh new file mode 100644 index 00000000..34fced23 --- /dev/null +++ b/scripts/speculative/hunyuan_ocr/generate_vlm_hidden_for_draft_model.sh @@ -0,0 +1,20 @@ +# #!/bin/bash + +DATASET_PATH= +MODEL_NAME=tencent/HunyuanOCR +TARGET_BACKEND=hf +MODEL_MAX_LENGTH=8192 +CHAT_TEMPLATE_TYPE=hunyuan_vl +OUTPUT_DIR= +echo $DATASET_PATH +echo $OUTPUT_DIR +torchrun --nproc_per_node=8 tools/generate_hidden_for_draft_model.py \ + --modal_type VLM \ + --dataset_path $DATASET_PATH \ + --model_name $MODEL_NAME \ + --target_backend $TARGET_BACKEND \ + --torch_dtype bfloat16 \ + --model_max_length $MODEL_MAX_LENGTH \ + --chat_template_type $CHAT_TEMPLATE_TYPE \ + --outdir $OUTPUT_DIR \ + --target_model_type hunyuan_vl diff --git a/scripts/speculative/hunyuan_ocr/train_eagle3_vlm_offline.sh b/scripts/speculative/hunyuan_ocr/train_eagle3_vlm_offline.sh new file mode 100644 index 00000000..8bfb2058 --- /dev/null +++ b/scripts/speculative/hunyuan_ocr/train_eagle3_vlm_offline.sh @@ -0,0 +1,46 @@ +#!/bin/bash + +export CONFIG_DIR=angelslim/compressor/speculative/train/configs +export TARGET_MODEL_NAME_OR_PATH=tencent/HunyuanOCR +export DRAFT_MODEL_CONFIG_PATH=$CONFIG_DIR/hunyuan_ocr-eagle3.json +export TRAIN_DATA_PATH= +export TRAIN_HIDDEN_PATH= +export EVAL_HIDDEN_PATH= +export OUTPUT_DIR= +export RUN_NAME=hunyuan-ocr-eagle3-angelslim +export MODEL_MAX_LENGTH=8192 +export LM_HEAD_KEY="model.embed_tokens.weight" +export CHAT_TEMPLATE_TYPE=hunyuan_vl +export EMBED_WEIGHT_KEY="model.embed_tokens.weight" + +torchrun --nproc_per_node=8 tools/train_eagle3_offline.py \ + --modal_type VLM \ + --target_model_name_or_path $TARGET_MODEL_NAME_OR_PATH \ + --draft_model_config_path $DRAFT_MODEL_CONFIG_PATH \ + --train_data_path $TRAIN_DATA_PATH \ + --train_hidden_path $TRAIN_HIDDEN_PATH \ + --eval_hidden_path $EVAL_HIDDEN_PATH \ + --output_dir $OUTPUT_DIR \ + --num_train_epochs 20 \ + --per_device_train_batch_size 1 \ + --per_device_eval_batch_size 1 \ + --gradient_accumulation_steps 1 \ + --save_strategy "steps" \ + --eval_strategy "steps" \ + --save_steps 10000 \ + --eval_steps 20000 \ + --learning_rate 1e-4 \ + --weight_decay 0.0 \ + --warmup_ratio 0.1 \ + --lr_scheduler_type "constant" \ + --logging_steps 100 \ + --model_max_length $MODEL_MAX_LENGTH \ + --lm_head_key $LM_HEAD_KEY \ + --embed_weight_key $EMBED_WEIGHT_KEY \ + --chat_template_type $CHAT_TEMPLATE_TYPE \ + --deepspeed $CONFIG_DIR/deepspeed_zero3.json \ + --report_to wandb \ + --run_name $RUN_NAME \ + --num_proc 8 \ + --training_time_test_length 4 \ + --bf16 diff --git a/scripts/speculative/hunyuan_ocr/train_eagle3_vlm_online.sh b/scripts/speculative/hunyuan_ocr/train_eagle3_vlm_online.sh new file mode 100644 index 00000000..7c685252 --- /dev/null +++ b/scripts/speculative/hunyuan_ocr/train_eagle3_vlm_online.sh @@ -0,0 +1,38 @@ +#!/bin/bash + +export CONFIG_DIR=angelslim/compressor/speculative/train/configs +export TARGET_MODEL_NAME_OR_PATH=tencent/HunyuanOCR +export DRAFT_MODEL_CONFIG_PATH=$CONFIG_DIR/hunyuan_ocr-eagle3.json +export TRAIN_DATA_PATH= +export EVAL_DATA_PATH= +export OUTPUT_DIR= +export EMBED_WEIGHT_KEY="model.embed_tokens.weight" +export MODEL_MAX_LENGTH=8192 +export CHAT_TEMPLATE_TYPE=hunyuan_vl +export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 + +torchrun --nproc_per_node=8 tools/train_eagle3_online.py \ + --modal_type VLM \ + --target_model_name_or_path $TARGET_MODEL_NAME_OR_PATH \ + --draft_model_config_path $DRAFT_MODEL_CONFIG_PATH \ + --train_data_path $TRAIN_DATA_PATH \ + --eval_data_path $EVAL_DATA_PATH \ + --output_dir $OUTPUT_DIR \ + --num_train_epochs 20 \ + --per_device_train_batch_size 1 \ + --per_device_eval_batch_size 1 \ + --gradient_accumulation_steps 1 \ + --num_proc 8 \ + --save_strategy "steps" \ + --save_steps 1000 \ + --learning_rate 1e-4 \ + --weight_decay 0.0 \ + --warmup_ratio 0.1 \ + --lr_scheduler_type "constant" \ + --logging_steps 20 \ + --model_max_length $MODEL_MAX_LENGTH \ + --embed_weight_key $EMBED_WEIGHT_KEY \ + --deepspeed $CONFIG_DIR/deepspeed_zero3.json \ + --chat_template_type $CHAT_TEMPLATE_TYPE \ + --report_to none \ + --run_name hunyuan-ocr-eagle3-angelslim \ No newline at end of file diff --git a/scripts/speculative/qwen3_vl/generate_vlm_hidden_for_draft_model.sh b/scripts/speculative/qwen3_vl/generate_vlm_hidden_for_draft_model.sh new file mode 100644 index 00000000..e43732e3 --- /dev/null +++ b/scripts/speculative/qwen3_vl/generate_vlm_hidden_for_draft_model.sh @@ -0,0 +1,20 @@ +#!/bin/bash + +DATASET_PATH= +MODEL_NAME=Qwen/Qwen3-VL-4B-Instruct +TARGET_BACKEND=hf +MODEL_MAX_LENGTH=8192 +CHAT_TEMPLATE_TYPE=qwen3_vl +OUTPUT_DIR= + +torchrun --nproc_per_node=8 tools/generate_hidden_for_draft_model.py \ + --modal_type VLM \ + --dataset_path $DATASET_PATH \ + --model_name $MODEL_NAME \ + --target_backend $TARGET_BACKEND \ + --torch_dtype bfloat16 \ + --model_max_length $MODEL_MAX_LENGTH \ + --chat_template_type $CHAT_TEMPLATE_TYPE \ + --outdir $OUTPUT_DIR \ + --num_proc 8 \ + --target_model_type qwen3_vl diff --git a/scripts/speculative/qwen3_vl/train_eagle3_vlm_offline.sh b/scripts/speculative/qwen3_vl/train_eagle3_vlm_offline.sh new file mode 100644 index 00000000..d2b8fcc6 --- /dev/null +++ b/scripts/speculative/qwen3_vl/train_eagle3_vlm_offline.sh @@ -0,0 +1,46 @@ +#!/bin/bash + +export CONFIG_DIR=angelslim/compressor/speculative/train/configs +export TARGET_MODEL_NAME_OR_PATH=Qwen/Qwen3-VL-4B-Instruct +export DRAFT_MODEL_CONFIG_PATH=$CONFIG_DIR/qwen3-vl-4b-eagle3-mrope.json +export TRAIN_DATA_PATH= +export TRAIN_HIDDEN_PATH= +export EVAL_HIDDEN_PATH= +export OUTPUT_DIR= +export RUN_NAME=qwen3-4b-eagle3-angelslim +export MODEL_MAX_LENGTH=8192 +export LM_HEAD_KEY="model.language_model.embed_tokens.weight" +export CHAT_TEMPLATE_TYPE=qwen3_vl +export EMBED_WEIGHT_KEY="model.language_model.embed_tokens.weight" + +torchrun --nproc_per_node=8 tools/train_eagle3_offline.py \ + --modal_type VLM \ + --target_model_name_or_path $TARGET_MODEL_NAME_OR_PATH \ + --draft_model_config_path $DRAFT_MODEL_CONFIG_PATH \ + --train_data_path $TRAIN_DATA_PATH \ + --train_hidden_path $TRAIN_HIDDEN_PATH \ + --eval_hidden_path $EVAL_HIDDEN_PATH \ + --output_dir $OUTPUT_DIR \ + --num_train_epochs 20 \ + --per_device_train_batch_size 1 \ + --per_device_eval_batch_size 1 \ + --gradient_accumulation_steps 1 \ + --save_strategy "steps" \ + --eval_strategy "steps" \ + --save_steps 5000 \ + --eval_steps 20000 \ + --learning_rate 1e-4 \ + --weight_decay 0.0 \ + --warmup_ratio 0.1 \ + --lr_scheduler_type "constant" \ + --logging_steps 100 \ + --model_max_length $MODEL_MAX_LENGTH \ + --lm_head_key $LM_HEAD_KEY \ + --embed_weight_key $EMBED_WEIGHT_KEY \ + --chat_template_type $CHAT_TEMPLATE_TYPE \ + --deepspeed $CONFIG_DIR/deepspeed_zero3.json \ + --report_to none \ + --run_name $RUN_NAME \ + --num_proc 8 \ + --training_time_test_length 3 \ + --bf16 diff --git a/scripts/speculative/qwen3_vl/train_eagle3_vlm_online.sh b/scripts/speculative/qwen3_vl/train_eagle3_vlm_online.sh new file mode 100644 index 00000000..c8840a24 --- /dev/null +++ b/scripts/speculative/qwen3_vl/train_eagle3_vlm_online.sh @@ -0,0 +1,38 @@ +#!/bin/bash + +export CONFIG_DIR=angelslim/compressor/speculative/train/configs +export TARGET_MODEL_NAME_OR_PATH=Qwen/Qwen3-VL-4B-Instruct +export DRAFT_MODEL_CONFIG_PATH=$CONFIG_DIR/qwen3-vl-4b-eagle3-mrope.json +export TRAIN_DATA_PATH= +export EVAL_DATA_PATH= +export OUTPUT_DIR= +export EMBED_WEIGHT_KEY="model.language_model.embed_tokens.weight" +export MODEL_MAX_LENGTH=4096 +export CHAT_TEMPLATE_TYPE=qwen3_vl +export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 + +torchrun --nproc_per_node=8 tools/train_eagle3_online.py \ + --modal_type VLM \ + --target_model_name_or_path $TARGET_MODEL_NAME_OR_PATH \ + --draft_model_config_path $DRAFT_MODEL_CONFIG_PATH \ + --train_data_path $TRAIN_DATA_PATH \ + --eval_data_path $EVAL_DATA_PATH \ + --output_dir $OUTPUT_DIR \ + --num_train_epochs 20 \ + --per_device_train_batch_size 1 \ + --per_device_eval_batch_size 1 \ + --gradient_accumulation_steps 1 \ + --num_proc 8 \ + --save_strategy "steps" \ + --save_steps 1000 \ + --learning_rate 1e-4 \ + --weight_decay 0.0 \ + --warmup_ratio 0.1 \ + --lr_scheduler_type "constant" \ + --logging_steps 20 \ + --model_max_length $MODEL_MAX_LENGTH \ + --embed_weight_key $EMBED_WEIGHT_KEY \ + --deepspeed $CONFIG_DIR/deepspeed_zero3.json \ + --chat_template_type $CHAT_TEMPLATE_TYPE \ + --report_to none \ + --run_name qwen3-4b-eagle3-angelslim \ No newline at end of file diff --git a/tools/generate_hidden_for_draft_model.py b/tools/generate_hidden_for_draft_model.py index 40e3dc15..334ff1d8 100644 --- a/tools/generate_hidden_for_draft_model.py +++ b/tools/generate_hidden_for_draft_model.py @@ -15,6 +15,7 @@ import argparse import logging import os +from datetime import timedelta from pathlib import Path from typing import Any, Dict, Tuple @@ -46,7 +47,7 @@ def setup_distributed(): local_rank = int(os.environ["LOCAL_RANK"]) # Initialize process group - dist.init_process_group(backend="nccl") + dist.init_process_group(backend="nccl", timeout=timedelta(minutes=60)) torch.cuda.set_device(local_rank) return rank, world_size, local_rank @@ -123,13 +124,14 @@ def _process_single_sample(self, idx: int, row: Dict[str, Any]) -> bool: try: # Generate aux and target hiddens device = decide_device_for_distributed() - results = self.target_model.get_aux_and_target_hiddens( - input_ids=row["input_ids"].to(device), - ) + for k, v in row.items(): + if isinstance(v, torch.Tensor) and v is not None: + row[k] = v.to(device) + results = self.target_model.get_aux_and_target_hiddens(**row) # hidden_states: B, N, 3*D # target_hiddens: B, N, D for k, v in results.items(): - results[k] = v.cpu() + results[k] = v.cpu() if isinstance(v, torch.Tensor) else v # Prepare data point data_point = { @@ -231,6 +233,12 @@ def parse_arguments() -> argparse.Namespace: type=str, help="Target model name or path (if different from model_name)", ) + parser.add_argument( + "--target_model_type", + type=str, + default=None, + help="Target model name or path (if different from model_name)", + ) parser.add_argument( "--target_backend", type=str, @@ -339,6 +347,7 @@ def load_dataset(args: argparse.Namespace, tokenizer, rank: int): dataset_manager = DatasetManager( data_args=args, tokenizer=tokenizer, + target_model_type=args.target_model_type, model_max_length=args.model_max_length, chat_template_type=args.chat_template_type, display=display, @@ -423,6 +432,7 @@ def main(): model_path=args.target_model_name_or_path or args.model_name, torch_dtype=torch_dtype, trust_remote_code=args.trust_remote_code, + target_model_type=args.target_model_type, ) # Load dataset diff --git a/tools/train_eagle3_offline.py b/tools/train_eagle3_offline.py index 0b50b792..9354ee4e 100644 --- a/tools/train_eagle3_offline.py +++ b/tools/train_eagle3_offline.py @@ -254,6 +254,9 @@ def parse_args(): training_group.add_argument( "--save_strategy", type=str, default="no", help="Save strategy for checkpoints" ) + training_group.add_argument( + "--eval_strategy", type=str, default="no", help="Eval strategy for checkpoints" + ) training_group.add_argument( "--lr_scheduler_type", type=str, @@ -317,11 +320,14 @@ def train(): f"(chat template: {args.chat_template_type})" ) + target_model_type = getattr(draft_model_config, "target_model_type", None) + dataset_manager = DatasetManager( data_args=args, tokenizer=tokenizer, model_max_length=args.model_max_length, chat_template_type=args.chat_template_type, + target_model_type=target_model_type, ) ( @@ -380,6 +386,7 @@ def train(): } checkpoint_args = { + "eval_strategy": args.eval_strategy, "save_strategy": args.save_strategy, "save_steps": args.save_steps, "save_total_limit": args.save_total_limit, diff --git a/tools/train_eagle3_online.py b/tools/train_eagle3_online.py index 48b473d8..ca09237e 100644 --- a/tools/train_eagle3_online.py +++ b/tools/train_eagle3_online.py @@ -276,6 +276,10 @@ def train(): } torch_dtype = dtype_mapping.get(args.torch_dtype, torch.bfloat16) + rank0_print("Loading draft model config...") + draft_model_config = DraftModelConfig.from_file(args.draft_model_config_path) + target_model_type = getattr(draft_model_config, "target_model_type", None) + # Create target model with specified backend using factory function rank0_print(f"Loading target model with {args.target_backend} backend...") target_model = create_target_model( @@ -284,12 +288,12 @@ def train(): modal_type=args.modal_type, torch_dtype=torch_dtype, trust_remote_code=args.trust_remote_code, + target_model_type=target_model_type, ) rank0_print("Target model loaded successfully") # Create draft model rank0_print("Loading draft model...") - draft_model_config = DraftModelConfig.from_file(args.draft_model_config_path) rank0_print(f"draft_model_config: {draft_model_config}") draft_model = create_draft_model(draft_model_config) draft_model.load_embed_weights( @@ -309,6 +313,7 @@ def train(): model_max_length=args.model_max_length, chat_template_type=args.chat_template_type, display=args.display, + target_model_type=target_model_type, ) train_dataset, eval_dataset, data_collator = ( dataset_manager.create_online_datasets() diff --git a/tools/vllm_offline_eagle3_vlm_batch.py b/tools/vllm_offline_eagle3_vlm_batch.py new file mode 100755 index 00000000..b9446e6b --- /dev/null +++ b/tools/vllm_offline_eagle3_vlm_batch.py @@ -0,0 +1,361 @@ +# Copyright 2025 Tencent Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +usage: +for task in "Lin-Chen/MMStar" "HuggingFaceH4/MATH-500" "MMMU/MMMU"; do + python3 ./tools/vllm_offline_eagle3_vlm_batch.py \ + --target_model "$MODEL_DIR" \ + --draft_model "$EAGLE_DIR" \ + --use_eagle \ + --num_spec_tokens 4 \ + --dataset "$task" \ + --num_prompts 80 \ + --temp 0 \ + --max_num_seqs 1 \ + --output_len 1024 \ + --output_file "$OUTPUT_FILE" +done +""" + +import argparse +import base64 +import json +import os +import time +from io import BytesIO + +from datasets import load_dataset +from PIL import Image +from vllm import LLM, SamplingParams + + +def pil_to_base64(img): + if img.mode != "RGB": + img = img.convert("RGB") + buffered = BytesIO() + img.save(buffered, format="JPEG") + img_str = base64.b64encode(buffered.getvalue()).decode("utf-8") + return f"data:image/jpeg;base64,{img_str}" + + +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument("--target_model", type=str, default="Qwen/Qwen3-VL-4B-Instruct") + parser.add_argument( + "--draft_model", type=str, default=None, help="Path to draft model" + ) + parser.add_argument("--dataset", type=str, default="textvqa", help="Dataset to use") + parser.add_argument( + "--use_eagle", + action="store_true", + help="Enable speculative decoding with Eagle", + ) + parser.add_argument( + "--output_file", type=str, default="results/qwen3-vl-4b-eagle3-results.jsonl" + ) + parser.add_argument( + "--num_prompts", + type=int, + default=None, + help="Number of prompts to run (default: all)", + ) + parser.add_argument( + "--num_spec_tokens", type=int, default=2, help="Number of speculative tokens" + ) + parser.add_argument("--max_num_seqs", type=int, default=1) + parser.add_argument( + "--temp", type=float, default=0, help="Number of speculative tokens" + ) + parser.add_argument("--tp", type=int, default=1) + parser.add_argument("--output_len", type=int, default=1024) + return parser.parse_args() + + +CAT_SHORT2LONG = { + "acc": "Accounting", + "agri": "Agriculture", + "arch": "Architecture_and_Engineering", + "art": "Art", + "art_theory": "Art_Theory", + "bas_med": "Basic_Medical_Science", + "bio": "Biology", + "chem": "Chemistry", + "cli_med": "Clinical_Medicine", + "cs": "Computer_Science", + "design": "Design", + "diag_med": "Diagnostics_and_Laboratory_Medicine", + "econ": "Economics", + "elec": "Electronics", + "ep": "Energy_and_Power", + "fin": "Finance", + "geo": "Geography", + "his": "History", + "liter": "Literature", + "manage": "Manage", + "mark": "Marketing", + "mate": "Materials", + "math": "Math", + "mech": "Mechanical_Engineering", + "music": "Music", + "phar": "Pharmacy", + "phys": "Physics", + "psy": "Psychology", + "pub_health": "Public_Health", + "socio": "Sociology", +} + + +def main(): + args = parse_args() + + # Load dataset + print(f"Loading {args.dataset} dataset...") + if args.dataset == "MMMU/MMMU": + ds = load_dataset("args.dataset", split="test", trust_remote_code=True) + elif args.dataset == "Lin-Chen/MMStar": + ds = load_dataset(args.dataset, split="val", trust_remote_code=True) + elif args.dataset == "hunyuan-ocr": + ds = load_dataset("./dataset/hunyuan-ocr", split="test", trust_remote_code=True) + else: + ds = load_dataset(args.dataset, split="test", trust_remote_code=True) + if args.num_prompts is not None: + ds = ds.select(range(min(args.num_prompts, len(ds)))) + + print(f"Loaded {len(ds)} samples.") + + prompts = [] + if args.dataset == "lmms-lab/textvqa": + for item in ds: + # Convert PIL image to base64 + image_url = pil_to_base64(item["image"]) + prompts.append( + [ + { + "role": "user", + "content": [ + {"type": "image_url", "image_url": {"url": image_url}}, + {"type": "text", "text": item["question"]}, + ], + } + ] + ) + elif args.dataset == "MMMU/MMMU": + for item in ds: + # Convert PIL image to base64 + image_url = pil_to_base64(item["image_1"]) + question = item["question"].replace("", "") + prompts.append( + [ + { + "role": "user", + "content": [ + {"type": "image_url", "image_url": {"url": image_url}}, + {"type": "text", "text": question}, + ], + } + ] + ) + elif args.dataset == "Lin-Chen/MMStar": + for item in ds: + # Convert PIL image to base64 + image_url = pil_to_base64(item["image"]) + question = item["question"].replace("", "") + prompts.append( + [ + { + "role": "user", + "content": [ + {"type": "image_url", "image_url": {"url": image_url}}, + {"type": "text", "text": question}, + ], + } + ] + ) + elif args.dataset == "HuggingFaceH4/MATH-500": + for item in ds: + prompts.append([{"role": "user", "content": item["problem"]}]) + elif args.dataset == "hunyuan-ocr": + for item in ds: + image_url = pil_to_base64( + Image.open( + os.path.join("./dataset/hunyuan-ocr/images", item["img_path"]) + ) + ) + prompts.append( + [ + { + "role": "user", + "content": [ + {"type": "image_url", "image_url": {"url": image_url}}, + {"type": "text", "text": item["question"]}, + ], + } + ] + ) + else: + raise ValueError(f"Unsupported dataset: {args.dataset}") + + speculative_config = None + if args.use_eagle: + if args.draft_model: + speculative_config = { + "method": "eagle3", + "model": args.draft_model, + "num_speculative_tokens": args.num_spec_tokens, + } + else: + print( + "Warning: use_eagle is set but no draft_model provided. " + "Running without speculative decoding." + ) + + print( + f"Initializing LLM with target_model={args.target_model}, " + f"speculative_config={speculative_config}" + ) + + llm = LLM( + model=args.target_model, + trust_remote_code=True, + tensor_parallel_size=args.tp, + gpu_memory_utilization=0.9, + speculative_config=speculative_config, + max_num_seqs=args.max_num_seqs, + enforce_eager=True, + disable_log_stats=False, + max_model_len=32768, + limit_mm_per_prompt={"image": 1}, + disable_chunked_mm_input=False, + ) + + sampling_params = SamplingParams(temperature=args.temp, max_tokens=args.output_len) + + print("Starting generation...") + start_time = time.perf_counter() + outputs = llm.chat(prompts, sampling_params=sampling_params) + end_time = time.perf_counter() + total_time = end_time - start_time + print(f"Generation finished in {total_time:.2f} seconds.") + + # Process results + results_data = [] + for i, output in enumerate(outputs): + generated_text = output.outputs[0].text + if args.dataset == "lmms-lab/textvqa": + results_data.append( + { + "question_id": ds[i]["question_id"], + "image_id": ds[i]["image_id"], + "question": ds[i]["question"], + "generated_text": generated_text, + "answers": ds[i]["answers"], + } + ) + elif args.dataset == "HuggingFaceH4/MATH-500": + results_data.append( + { + "problem": ds[i]["problem"], + "generated_text": generated_text, + "solution": ds[i].get("solution", ""), + "answer": ds[i].get("answer", ""), + } + ) + elif args.dataset == "MMMU/MMMU": + results_data.append( + { + "id": ds[i]["id"], + "question": ds[i]["question"], + "generated_text": generated_text, + "answer": ds[i]["answer"], + } + ) + elif args.dataset == "Lin-Chen/MMStar": + results_data.append( + { + "id": ds[i]["index"], + "question": ds[i]["question"], + "generated_text": generated_text, + "answer": ds[i]["answer"], + } + ) + elif args.dataset == "hunyuan-ocr": + results_data.append( + { + "question": ds[i]["question"], + "generated_text": generated_text, + "answer": ds[i]["answer"], + } + ) + + total_num_output_tokens = sum( + len(output.outputs[0].token_ids) for output in outputs + ) + + metrics_info = { + "total_time": total_time, + "avg_time_per_sample": total_time / len(prompts) if prompts else 0, + "use_eagle": args.use_eagle, + "output_throughput": total_num_output_tokens / total_time, + } + + if args.use_eagle and speculative_config: + try: + metrics = llm.get_metrics() + + total_num_output_tokens = sum( + len(output.outputs[0].token_ids) for output in outputs + ) + num_drafts = 0 + num_accepted_tokens = 0 + acceptance_counts = [0] * args.num_spec_tokens + + for metric in metrics: + if metric.name == "vllm:spec_decode_num_drafts": + num_drafts += metric.value + elif metric.name == "vllm:spec_decode_num_accepted_tokens": + num_accepted_tokens += metric.value + elif metric.name == "vllm:spec_decode_num_accepted_tokens_per_pos": + for pos in range(len(metric.values)): + acceptance_counts[pos] += metric.values[pos] + + acceptance_rates = {} + for i in range(len(acceptance_counts)): + acceptance_rate = ( + acceptance_counts[i] / num_drafts if num_drafts > 0 else 0 + ) + acceptance_rates[f"acceptance_rate_pos_{i}"] = round(acceptance_rate, 4) + acceptance_length = ( + 1 + (num_accepted_tokens / num_drafts) if num_drafts > 0 else 1 + ) + metrics_info["mean_acceptance_length"] = acceptance_length + metrics_info["num_drafts"] = num_drafts + metrics_info["num_accepted_tokens"] = num_accepted_tokens + metrics_info["acceptance_rates"] = acceptance_rates + + print(f"Mean acceptance length: {acceptance_length:.2f}") + print(f"acceptance rates: {acceptance_rates}") + except Exception as e: + print(f"Error getting metrics: {e}") + + # Save to file + os.makedirs(os.path.dirname(args.output_file), exist_ok=True) + with open(args.output_file, "w") as f: + json.dump({"metrics": metrics_info, "results": results_data}, f, indent=2) + + print(f"Results saved to {args.output_file}") + + +if __name__ == "__main__": + main() diff --git a/tools/vllm_spec_benchmark.py b/tools/vllm_spec_benchmark.py index f80bd40c..5d124961 100644 --- a/tools/vllm_spec_benchmark.py +++ b/tools/vllm_spec_benchmark.py @@ -96,7 +96,7 @@ def save_stats_to_jsonl(stats, output_file): # Append stats to jsonl file with open(output_file, "a", encoding="utf-8") as f: - json.dump(stats, f, ensure_ascii=False) + json.dump(stats, f, indent=2, ensure_ascii=False) f.write("\n") print(f"Stats saved to: {output_file}")