diff --git a/scripts/train_dflash.py b/scripts/train_dflash.py
index 808e928c..6a66e3d6 100755
--- a/scripts/train_dflash.py
+++ b/scripts/train_dflash.py
@@ -20,7 +20,7 @@ from tqdm import tqdm
 from transformers import AutoConfig, AutoTokenizer

-from datasets import load_dataset
+from datasets import load_dataset, load_from_disk
 from specforge.args import SGLangBackendArgs, TrackerArgs
 from specforge.core.dflash import OnlineDFlashModel
 from specforge.data import build_eagle3_dataset, prepare_dp_dataloaders
@@ -35,6 +35,8 @@ from specforge.tracker import create_tracker
 from specforge.utils import get_last_checkpoint, print_on_rank0, print_with_rank

+logging.getLogger("sglang.srt.mem_cache.memory_pool").setLevel(logging.WARNING)
+

 def parse_args():
     parser = argparse.ArgumentParser(description="Train DFlash Draft Model")
@@ -94,6 +96,9 @@ def parse_args():
         default=None,
         help="LM head weight key in the target model. Default: 'lm_head.weight'.",
     )
+    model_group.add_argument("--is-vlm", action="store_true")
+    model_group.add_argument("--min-pixels", type=int, default=50176)
+    model_group.add_argument("--max-pixels", type=int, default=802816)

     dataset_group = parser.add_argument_group("dataset")
     dataset_group.add_argument("--train-data-path", type=str, required=True)
@@ -207,7 +212,16 @@ def build_models(args) -> Tuple[DFlashTargetModel, DFlashDraftModel]:
     return target_model, draft_model


-def build_dataloader(args, tokenizer) -> Tuple[DataLoader, Optional[DataLoader]]:
+def _load_raw_dataset(data_path: str):
+    """load jsonl"""
+    if os.path.isdir(data_path):
+        return load_from_disk(data_path)
+    return load_dataset("json", data_files=data_path)["train"]
+
+
+def build_dataloader(
+    args, tokenizer, processor=None
+) -> Tuple[DataLoader, Optional[DataLoader]]:
     """Build train and eval dataloaders."""
     import hashlib

@@ -219,7 +233,7 @@ def build_dataloader(args, tokenizer) -> Tuple[DataLoader, Optional[DataLoader]]
     )
     cache_key = hashlib.md5(cache_params_string.encode()).hexdigest()

-    train_dataset = load_dataset("json", data_files=args.train_data_path)["train"]
+    train_dataset = _load_raw_dataset(args.train_data_path)
     train_eagle3_dataset = build_eagle3_dataset(
         dataset=train_dataset,
         tokenizer=tokenizer,
@@ -229,8 +243,9 @@ def build_dataloader(args, tokenizer) -> Tuple[DataLoader, Optional[DataLoader]]
         cache_dir=os.path.join(args.cache_dir, "processed_dataset"),
         cache_key=cache_key,
         num_proc=args.build_dataset_num_proc,
+        is_vlm=args.is_vlm,
+        processor=processor,
     )
-
     min_loss_tokens = 2 * args.block_size
     original_size = len(train_eagle3_dataset)
     train_eagle3_dataset = train_eagle3_dataset.filter(
@@ -246,17 +261,20 @@ def build_dataloader(args, tokenizer) -> Tuple[DataLoader, Optional[DataLoader]]
         num_workers=args.dataloader_num_workers,
         shuffle=True,
         process_group=get_dp_group(),
+        is_vlm=args.is_vlm,
     )

     eval_dataloader = None
     if args.eval_data_path:
-        eval_dataset = load_dataset("json", data_files=args.eval_data_path)["train"]
+        eval_dataset = _load_raw_dataset(args.eval_data_path)
         eval_eagle3_dataset = build_eagle3_dataset(
             dataset=eval_dataset,
             tokenizer=tokenizer,
             chat_template=args.chat_template,
             max_length=args.max_length,
             is_preformatted=args.is_preformatted,
+            is_vlm=args.is_vlm,
+            processor=processor,
         )
         eval_dataloader = prepare_dp_dataloaders(
             eval_eagle3_dataset,
@@ -264,6 +282,7 @@ def build_dataloader(args, tokenizer) -> Tuple[DataLoader, Optional[DataLoader]]
             num_workers=args.dataloader_num_workers,
             shuffle=False,
             process_group=get_dp_group(),
+            is_vlm=args.is_vlm,
         )

     return train_dataloader, eval_dataloader
@@ -396,7 +415,23 @@ def main():
             f"step {resume_state['global_step']}"
         )

-    tokenizer = AutoTokenizer.from_pretrained(args.target_model_path)
+    tokenizer = AutoTokenizer.from_pretrained(
+        args.target_model_path, trust_remote_code=args.trust_remote_code
+    )
+
+    processor = None
+    if args.is_vlm:
+        from transformers import AutoProcessor
+
+        processor = AutoProcessor.from_pretrained(
+            args.target_model_path,
+            min_pixels=args.min_pixels,
+            max_pixels=args.max_pixels,
+            trust_remote_code=args.trust_remote_code,
+        )
+        print_on_rank0(
+            f"Loaded VLM processor (min_pixels={args.min_pixels}, max_pixels={args.max_pixels})"
+        )

     if args.mask_token_id is not None:
         mask_token_id = args.mask_token_id
@@ -412,8 +447,7 @@ def main():
     draft_model.config.dflash_config["target_layer_ids"] = draft_model.target_layer_ids
     print_on_rank0(f"dflash_config: {draft_model.config.dflash_config}")

-    train_dataloader, eval_dataloader = build_dataloader(args, tokenizer)
-
+    train_dataloader, eval_dataloader = build_dataloader(args, tokenizer, processor)
     steps_per_epoch = math.ceil(len(train_dataloader) / args.accumulation_steps)
     total_steps = args.num_epochs * steps_per_epoch
     print_on_rank0(f"Total training steps: {total_steps}")
@@ -496,11 +530,30 @@ def main():
                 continue
             global_step += 1

+            input_ids_cpu = data["input_ids"]
+            attention_mask_cpu = data["attention_mask"]
+            loss_mask_cpu = data["loss_mask"]
+
             input_ids = data["input_ids"].cuda()
-            attention_mask = data["attention_mask"].cuda()
             loss_mask = data["loss_mask"].cuda()
+            pixel_values = None
+            image_grid_thw_cpu = None
+            if (
+                args.is_vlm
+                and "pixel_values" in data
+                and data["pixel_values"] is not None
+            ):
+                pixel_values = data["pixel_values"].cuda()
+                image_grid_thw_cpu = [
+                    thw.squeeze() if thw is not None else None
+                    for thw in data["image_grid_thw"]
+                ]

             target_output = target_model.generate_dflash_data(
-                input_ids, attention_mask, loss_mask
+                input_ids_cpu,
+                attention_mask_cpu,
+                loss_mask_cpu,
+                pixel_values=pixel_values,
+                image_grid_thw=image_grid_thw_cpu,
             )
             hidden_states = target_output.hidden_states.cuda()  # Ensure on GPU
diff --git a/specforge/data/preprocessing.py b/specforge/data/preprocessing.py
index d5af9479..426ddf5e 100644
--- a/specforge/data/preprocessing.py
+++ b/specforge/data/preprocessing.py
@@ -202,8 +202,6 @@ def preprocess_vlm_conversations(
         - pixel_values: List of pixel values for images in the examples.
         - image_grid_thw: List of image grid tensors.
     """
-    system_prompt = chat_template.system_prompt
-
    # prepare result
    results = {
        "input_ids": [],
@@ -213,36 +211,71 @@
        "image_grid_thw": [],
    }

-    # Note: currently, we assume that each example has only one image
-    for i, image in enumerate(examples["image"]):
+    for i, images in enumerate(examples["images"]):
        source = examples["conversations"][i]
-        messages = [{"role": "system", "content": system_prompt}]
+        messages = []

        if not source:
            # if the source is None, skip it
            continue

+        if not images:
+            text_messages = []
+            convroles = ["user", "assistant"]
+            for j, sentence in enumerate(source):
+                role = sentence["role"]
+                assert role == convroles[j % 2], f"unexpected role {role}"
+                text_messages.append({"role": role, "content": sentence["content"]})
+            conversation = processor.apply_chat_template(
+                text_messages,
+                tokenize=False,
+                add_generation_prompt=False,
+            )
+            encoding = processor(
+                text=[conversation],
+                max_length=max_length,
+                truncation=True,
+                return_tensors="pt",
+                return_offsets_mapping=True,
+                add_special_tokens=False,
+            )
+
+            input_ids = encoding.input_ids[0]
+            offsets = encoding.offset_mapping[0]
+
+            # get conversation with image info for loss mask generation
+            decoded_conversation = processor.tokenizer.decode(
+                encoding.input_ids[0], skip_special_tokens=False
+            )
+
+            # Apply loss mask
+            loss_mask = _apply_loss_mask_from_chat_template(
+                decoded_conversation, offsets, chat_template
+            )
+            results["input_ids"].append(input_ids[None, :])
+            results["loss_mask"].append(loss_mask[None, :])
+            results["attention_mask"].append(torch.ones_like(loss_mask)[None, :])
+            results["pixel_values"].append(torch.empty(0, 0).float())
+            results["image_grid_thw"].append([])
+            continue
+
        if source[0]["role"] != "user":
            # if the first message is not from user, skip it
            source = source[1:]

        convroles = ["user", "assistant"]
+        has_added_images = False
        for j, sentence in enumerate(source):
            role = sentence["role"]
            assert role == convroles[j % 2], f"unexpected role {role}"

            if role == "user":
-                # if the message is from user and has image, process the image
-                messages.append(
-                    {
-                        "role": role,
-                        "content": [
-                            {
-                                "type": "image",
-                                "image": image,
-                            },
-                            {"type": "text", "text": sentence["content"]},
-                        ],
-                    }
-                )
+                # Insert all images into the first user message
+                if not has_added_images:
+                    content = [{"type": "image", "image": img} for img in images]
+                    content.append({"type": "text", "text": sentence["content"]})
+                    messages.append({"role": role, "content": content})
+                    has_added_images = True
+                else:
+                    messages.append({"role": role, "content": sentence["content"]})
            else:
                messages.append({"role": role, "content": sentence["content"]})

@@ -273,7 +306,7 @@
        input_ids = encoding.input_ids[0]
        offsets = encoding.offset_mapping[0]
        pixel_values = encoding.pixel_values
-        image_grid_thw = encoding.image_grid_thw[0]
+        image_grid_thw = encoding.image_grid_thw  # shape: (num_images, 3)

        # get conversation with image info for loss mask generation
        decoded_conversation = processor.tokenizer.decode(
@@ -289,7 +322,7 @@
        results["loss_mask"].append(loss_mask[None, :])
        results["attention_mask"].append(torch.ones_like(loss_mask)[None, :])
        results["pixel_values"].append(pixel_values)
-        results["image_grid_thw"].append(image_grid_thw[None, :])
+        results["image_grid_thw"].append(image_grid_thw)

    return results

@@ -390,7 +423,7 @@ def preprocess_function(examples):
        # Parse tools: handle JSON strings from safe_conversations_generator
        tools = []
        for tool_item in tools_raw:
-            if isinstance(tool_item, (str, list)):
+            if isinstance(tool_item, str):
                try:
                    tools.append(json.loads(tool_item))
                except json.JSONDecodeError:
diff --git a/specforge/data/utils.py b/specforge/data/utils.py
index 93fd6f58..2d7a5c22 100644
--- a/specforge/data/utils.py
+++ b/specforge/data/utils.py
@@ -205,6 +205,10 @@ def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]:
        - attention_mask: torch.Tensor of shape (B, N)
        - loss_mask: torch.Tensor of shape (B, N)
        """
+        assert len(features) == 1, (
+            f"VlmDataCollatorWithPadding requires batch_size=1, got {len(features)}. "
+            "Set per_device_train_batch_size=1 in your training config."
+        )
        max_length = max(item["input_ids"].shape[1] for item in features)
        batch_input_ids = torch.cat(
            [self.paddingtensor2D(item["input_ids"], max_length) for item in features]
        )
@@ -218,12 +222,30 @@ def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]:
        batch_loss_mask = torch.cat(
            [self.paddingtensor2D(item["loss_mask"], max_length) for item in features]
        )
-        batch_pixel_values = torch.cat(
-            [item["pixel_values"] for item in features], dim=0
-        )
-        batch_image_grid_thw = torch.cat(
-            [item["image_grid_thw"] for item in features], dim=0
-        )
+        # Collect pixel_values and image_grid_thw per sample.
+        # Image samples have non-empty pixel_values; text-only samples have empty tensors.
+        all_pixel_values = []
+        all_image_grid_thw = []
+        for item in features:
+            pv = item.get("pixel_values")
+            thw = item.get("image_grid_thw")
+            if pv is not None and isinstance(pv, torch.Tensor) and pv.numel() > 0:
+                all_pixel_values.append(pv)
+                all_image_grid_thw.append(thw)
+            else:
+                all_image_grid_thw.append(None)
+
+        if all_pixel_values:
+            batch_pixel_values = torch.cat(all_pixel_values, dim=0)
+        else:
+            batch_pixel_values = None
+
+        # If all samples are text-only, set image_grid_thw to None
+        if all(thw is None for thw in all_image_grid_thw):
+            batch_image_grid_thw = None
+        else:
+            batch_image_grid_thw = all_image_grid_thw
+
        batch = {
            "input_ids": batch_input_ids,
            "attention_mask": batch_attention_mask,
@@ -304,17 +326,10 @@ def prepare_dp_dataloaders(


 def parse_harmony_message_content(content):
-    """
-    解析 content 字符串中的 Harmony 格式。
-    如果匹配到 Harmony 格式,返回包含 channel 和 content 的列表;
-    否则,返回原内容并标记为默认 channel。
-    """
-    # 匹配 <|channel|>xxx<|message|>yyy<|end|>
    pattern = r"<\|channel\|>(.*?)<\|message\|>(.*?)<\|end|>"
    matches = re.findall(pattern, content, re.DOTALL)
    if not matches:
-        # 如果没有匹配到 Harmony 标签,视作普通文本
        return [{"channel": "text", "content": content}]

    results = []
@@ -324,22 +339,17 @@ def parse_harmony_message_content(content):


 def process_harmony_conversations(conversation):
-    """
-    处理传入的 list[list[dict]] 结构
-    """
    new_conversation = []
    for msg in conversation:
        role = msg.get("role")
        original_content = msg.get("content", "")

-        # 解析 content 中的 Harmony 结构
        segments = parse_harmony_message_content(original_content)

-        # 为每个解析出的通道生成一个新的消息字典
        for seg in segments:
            new_msg = {
                "role": role,
-                "channel": seg["channel"],  # 新增字段标识通道
+                "channel": seg["channel"],
                "content": seg["content"],
            }
            new_conversation.append(new_msg)
diff --git a/specforge/modeling/target/dflash_target_model.py b/specforge/modeling/target/dflash_target_model.py
index 0df93823..d33cc57d 100644
--- a/specforge/modeling/target/dflash_target_model.py
+++ b/specforge/modeling/target/dflash_target_model.py
@@ -2,15 +2,28 @@
 from dataclasses import dataclass
 from typing import List, Optional

+import sglang.srt.managers.mm_utils as mm_utils
 import torch
 import torch.distributed as dist
 import torch.nn as nn
 from sglang.srt.configs.model_config import ModelConfig
-from sglang.srt.managers.schedule_batch import Req, ScheduleBatch
+from sglang.srt.layers.rotary_embedding import MRotaryEmbedding
+from sglang.srt.managers.mm_utils import (
+    MultiModalityDataPaddingPatternMultimodalTokens,
+    init_mm_embedding_cache,
+)
+from sglang.srt.managers.schedule_batch import (
+    Modality,
+    MultimodalDataItem,
+    MultimodalInputs,
+    Req,
+    ScheduleBatch,
+)
 from sglang.srt.managers.scheduler import Scheduler
 from sglang.srt.mem_cache.cache_init_params import CacheInitParams
 from sglang.srt.mem_cache.radix_cache import RadixCache
 from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode, ForwardBatch
+from sglang.srt.multimodal.processors.base_processor import BaseMultimodalProcessor
 from sglang.srt.sampling.sampling_params import SamplingParams
 from sglang.srt.server_args import ServerArgs
 from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
@@ -56,6 +69,8 @@ def generate_dflash_data(
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        loss_mask: torch.Tensor,
+        pixel_values: Optional[torch.Tensor] = None,
+        image_grid_thw: Optional[List[torch.Tensor]] = None,
    ) -> DFlashTargetOutput:
        """Generate context hidden states for DFlash training."""

@@ -68,6 +83,34 @@ class SGLangDFlashTargetModel(DFlashTargetModel):
    def __init__(self, model_runner: SGLangRunner):
        super().__init__()
        self.model_runner = model_runner
+        self._init_vlm_attributes()
+        cache_params = CacheInitParams(
+            disable=True,
+            req_to_token_pool=self.model_runner.req_to_token_pool,
+            token_to_kv_pool_allocator=self.model_runner.token_to_kv_pool_allocator,
+            page_size=self.model_runner.server_args.page_size,
+        )
+        self.dummy_tree_cache = RadixCache(cache_params)
+
+    def _init_vlm_attributes(self):
+        """Detect and cache VLM-specific token IDs / vision config from the HF config."""
+        hf_config = self.model_runner.model_config.hf_config
+        self.is_vlm = hasattr(hf_config, "vision_config")
+        if not self.is_vlm:
+            return
+
+        init_mm_embedding_cache(512 * 1024 * 1024)  # 512 MB embedding cache
+
+        self.image_token_id = getattr(hf_config, "image_token_id", None)
+        self.video_token_id = getattr(hf_config, "video_token_id", None)
+        self.vision_start_token_id = getattr(hf_config, "vision_start_token_id", None)
+        self.vision_end_token_id = getattr(hf_config, "vision_end_token_id", None)
+
+        vision_config = hf_config.vision_config
+        self.spatial_merge_size = getattr(vision_config, "spatial_merge_size", 2)
+        self.vlm_model_type = getattr(vision_config, "model_type", "")
+
+        self.tokens_per_second = None

    @classmethod
    def from_pretrained(
@@ -85,7 +128,7 @@
            trust_remote_code=trust_remote_code,
            dtype=torch_dtype,
            enable_return_hidden_states=True,  # Critical for DFlash
-            disable_cuda_graph=True,
+            # disable_cuda_graph=True,
            tp_size=tp_size,
            pp_size=1,
            **kwargs,
        )
@@ -112,25 +155,20 @@
    def set_capture_layers(self, layer_ids: List[int]) -> None:
        super().set_capture_layers(layer_ids)

-        if hasattr(self.model_runner.model, "set_eagle3_layers_to_capture"):
+        if hasattr(self.model_runner.model, "set_dflash_layers_to_capture"):
+            self.model_runner.model.set_dflash_layers_to_capture(
+                [val + 1 for val in layer_ids]
+            )
+        elif hasattr(self.model_runner.model, "set_eagle3_layers_to_capture"):
            self.model_runner.model.set_eagle3_layers_to_capture(layer_ids)
-        print(self.model_runner.model.model.layers_to_capture)

-    @torch.no_grad
+    @torch.no_grad()
    def _extend(self, reqs):
-        cache_params = CacheInitParams(
-            disable=False,
-            req_to_token_pool=self.model_runner.req_to_token_pool,
-            token_to_kv_pool_allocator=self.model_runner.token_to_kv_pool_allocator,
-            page_size=self.model_runner.server_args.page_size,
-        )
-        tree_cache = RadixCache(cache_params)
-
        batch = ScheduleBatch.init_new(
            reqs=reqs,
            req_to_token_pool=self.model_runner.req_to_token_pool,
            token_to_kv_pool_allocator=self.model_runner.token_to_kv_pool_allocator,
-            tree_cache=tree_cache,
+            tree_cache=self.dummy_tree_cache,
            model_config=self.model_runner.model_config,
            enable_overlap=False,
            spec_algorithm=SpeculativeAlgorithm.NONE,
        )
@@ -180,42 +218,167 @@ def _extend(self, reqs):
        return hidden_states_list

+    def _build_vlm_reqs(
+        self,
+        input_ids_list: List[torch.Tensor],
+        pixel_values: Optional[List[torch.Tensor]] = None,
+        image_grid_thw: Optional[List[torch.Tensor]] = None,
+    ):
+        mm_utils.embedding_cache.clear()
+        sampling_params = SamplingParams(temperature=0, max_new_tokens=1)
+        reqs = []
+
+        # pixel_values is a single 2D tensor (total_patches, patch_dim) for Qwen2.5-VL
+        # We need to track offset and slice it based on image_grid_thw for each sample
+        pixel_values_offset = 0  # Track current offset in pixel_values
+        for idx, (input_id_, image_grid_thw_) in enumerate(
+            zip(
+                input_ids_list,
+                image_grid_thw,
+            )
+        ):
+            input_id_flat = input_id_.view(-1)
+
+            # Determine if this sample has image data
+            has_image = (
+                image_grid_thw_ is not None
+                and isinstance(image_grid_thw_, torch.Tensor)
+                and image_grid_thw_.numel() > 0
+            )
+
+            if has_image:
+                # Ensure image_grid_thw_ is 2D: (num_images, 3)
+                if image_grid_thw_.dim() == 1:
+                    image_grid_thw_ = image_grid_thw_.unsqueeze(0)  # (3,) -> (1, 3)
+                elif image_grid_thw_.dim() == 0:
+                    raise ValueError(
+                        f"image_grid_thw_ is 0-dim tensor, expected at least 1D. Value: {image_grid_thw_}"
+                    )
+
+                # Calculate num_patches for this sample: sum(t * h * w) for all images
+                num_patches = (
+                    (
+                        image_grid_thw_[:, 0]
+                        * image_grid_thw_[:, 1]
+                        * image_grid_thw_[:, 2]
+                    )
+                    .sum()
+                    .item()
+                )
+                num_patches = int(num_patches)
+
+                # Slice pixel_values for this sample
+                pixel_value_ = pixel_values[
+                    pixel_values_offset : pixel_values_offset + num_patches
+                ]
+                pixel_values_offset += num_patches
+
+                # Compute mrope positions for VLM models (e.g., Qwen2.5-VL)
+                mrope_positions, mrope_position_delta = MRotaryEmbedding.get_rope_index(
+                    spatial_merge_size=self.spatial_merge_size,
+                    image_token_id=self.image_token_id,
+                    video_token_id=self.video_token_id,
+                    vision_start_token_id=self.vision_start_token_id,
+                    model_type=self.vlm_model_type,
+                    input_ids=input_id_flat.unsqueeze(0),
+                    image_grid_thw=image_grid_thw_,
+                    tokens_per_second=self.tokens_per_second,
+                )
+
+                offset = BaseMultimodalProcessor.get_mm_items_offset(
+                    input_id_flat, self.image_token_id
+                )
+                mm_item = MultimodalDataItem(
+                    modality=Modality.IMAGE,
+                    feature=pixel_value_,  # torch.Tensor: (num_patches, patch_dim)
+                    pad_value=self.image_token_id,  # Required for placeholder tensor creation
+                    offsets=offset,  # List of (start, end) tuples
+                )
+                mm_item.set("image_grid_thw", image_grid_thw_)
+                mm_item.set_pad_value()
+                mm_inputs = MultimodalInputs(
+                    mm_items=[mm_item],
+                    im_token_id=self.image_token_id,
+                    im_start_id=self.vision_start_token_id,
+                    im_end_id=self.vision_end_token_id,
+                    mrope_positions=(
+                        mrope_positions.squeeze(1)
+                        if mrope_positions is not None
+                        else None
+                    ),
+                    mrope_position_delta=mrope_position_delta,
+                )
+                pattern = MultiModalityDataPaddingPatternMultimodalTokens()
+                input_id_list = pattern.pad_input_tokens(
+                    input_id_flat.tolist(), mm_inputs
+                )
+                req = Req(
+                    rid=str(idx),
+                    origin_input_text="",
+                    origin_input_ids=input_id_list,
+                    sampling_params=sampling_params,
+                )
+                req.fill_ids = req.origin_input_ids
+                req.extend_input_len = len(req.fill_ids) - len(req.prefix_indices)
+                req.logprob_start_len = len(req.origin_input_ids) - 1
+                req.multimodal_inputs = mm_inputs
+                reqs.append(req)
+            else:
+                # Text-only sample: create plain text req without multimodal inputs
+                req = Req(
+                    rid=str(idx),
+                    origin_input_text="",
+                    origin_input_ids=input_id_flat.tolist(),
+                    sampling_params=sampling_params,
+                )
+                req.fill_ids = req.origin_input_ids
+                req.extend_input_len = len(req.fill_ids) - len(req.prefix_indices)
+                req.logprob_start_len = len(req.origin_input_ids) - 1
+                reqs.append(req)
+        return reqs
+
    @torch.no_grad()
    def generate_dflash_data(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        loss_mask: torch.Tensor,
+        pixel_values: Optional[torch.Tensor] = None,
+        image_grid_thw: Optional[torch.Tensor] = None,
    ) -> DFlashTargetOutput:
-        sampling_params = SamplingParams(temperature=0, max_new_tokens=1)
-        reqs, data_cache = [], []
        if isinstance(input_ids, torch.Tensor):
            input_ids_list = torch.split(input_ids, 1, dim=0)
            attn_mask_list = torch.split(attention_mask, 1, dim=0)
            loss_mask_list = torch.split(loss_mask, 1, dim=0)

-        for idx, (curr_ids, curr_attn, curr_loss) in enumerate(
-            zip(input_ids_list, attn_mask_list, loss_mask_list)
-        ):
-            req = Req(
-                rid=str(idx),
-                origin_input_text="",
-                origin_input_ids=curr_ids.view(-1).tolist(),
-                sampling_params=sampling_params,
-            )
-            req.fill_ids = req.origin_input_ids
-            req.extend_input_len = len(req.fill_ids) - len(req.prefix_indices)
-            data_cache.append((curr_ids, curr_attn, curr_loss))
-            reqs.append(req)
+        is_vlm_batch = (
+            self.is_vlm and pixel_values is not None and image_grid_thw is not None
+        )
+
+        if is_vlm_batch:
+            reqs = self._build_vlm_reqs(input_ids_list, pixel_values, image_grid_thw)
+        else:
+            sampling_params = SamplingParams(temperature=0, max_new_tokens=1)
+            reqs = []
+            for idx, curr_ids in enumerate(input_ids_list):
+                req = Req(
+                    rid=str(idx),
+                    origin_input_text="",
+                    origin_input_ids=curr_ids.view(-1).tolist(),
+                    sampling_params=sampling_params,
+                )
+                req.fill_ids = req.origin_input_ids
+                req.extend_input_len = len(req.fill_ids) - len(req.prefix_indices)
+                reqs.append(req)

        hidden_states_list = self._extend(reqs)

        # Stack back to batch
        hidden_states = torch.cat([h.unsqueeze(0) for h in hidden_states_list], dim=0)
-        input_ids = torch.cat([d[0] for d in data_cache], dim=0)
-        attention_mask = torch.cat([d[1] for d in data_cache], dim=0)
-        loss_mask = torch.cat([d[2] for d in data_cache], dim=0)
+        input_ids = torch.cat(list(input_ids_list), dim=0)
+        attention_mask = torch.cat(list(attn_mask_list), dim=0)
+        loss_mask = torch.cat(list(loss_mask_list), dim=0)

        return DFlashTargetOutput(
            hidden_states=hidden_states,
@@ -261,13 +424,20 @@ def generate_dflash_data(
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        loss_mask: torch.Tensor,
+        pixel_values: Optional[List[torch.Tensor]] = None,
+        image_grid_thw: Optional[List[torch.Tensor]] = None,
    ) -> DFlashTargetOutput:
-        outputs = self.model(
+        model_kwargs = dict(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_hidden_states=True,
            use_cache=False,
        )
+        if pixel_values is not None:
+            model_kwargs["pixel_values"] = pixel_values
+        if image_grid_thw is not None:
+            model_kwargs["image_grid_thw"] = image_grid_thw
+        outputs = self.model(**model_kwargs)

        # hidden_states[0] = embedding output; hidden_states[i+1] = layer i output
        offset = 1
@@ -296,6 +466,8 @@ def get_dflash_target_model(
    **kwargs,
) -> DFlashTargetModel:
    if backend == "sglang":
+        if "enable_piecewise_cuda_graph" in kwargs:
+            del kwargs["enable_piecewise_cuda_graph"]
        return SGLangDFlashTargetModel.from_pretrained(
            pretrained_model_name_or_path=pretrained_model_name_or_path,
            torch_dtype=torch_dtype,
diff --git a/specforge/modeling/target/sglang_backend/model_runner.py b/specforge/modeling/target/sglang_backend/model_runner.py
index 501ee34e..5441fb04 100644
--- a/specforge/modeling/target/sglang_backend/model_runner.py
+++ b/specforge/modeling/target/sglang_backend/model_runner.py
@@ -116,17 +116,9 @@ def _(data, dim):
            rank=self.tp_size * self.pp_rank + self.tp_rank,
            local_rank=self.gpu_id,
        )
-        # NOTE: Updated for sglang 0.5.9
-        # - Removed torch_compile parameter (no longer supported)
-        # - Added new parameters: attention_data_parallel_size, attention_context_model_parallel_size, moe_data_model_parallel_size
-
-        # Debug: Print the values
        dp_size = getattr(self.server_args, "dp_size", 1)
        attn_cp_size = getattr(self.server_args, "attn_cp_size", 1)
        moe_dp_size = getattr(self.server_args, "moe_dp_size", 1)
-        print(
-            f"[DEBUG] tp_size={self.tp_size}, dp_size={dp_size}, attn_cp_size={attn_cp_size}, moe_dp_size={moe_dp_size}"
-        )

        initialize_model_parallel(
            tensor_model_parallel_size=self.tp_size,
diff --git a/specforge/modeling/target/target_utils.py b/specforge/modeling/target/target_utils.py
index 9dacba6b..a07bc9e1 100644
--- a/specforge/modeling/target/target_utils.py
+++ b/specforge/modeling/target/target_utils.py
@@ -54,10 +54,19 @@ def from_pretrained(
        config = AutoConfig.from_pretrained(
            model_path, cache_dir=cache_dir, trust_remote_code=trust_remote_code
        )
+
+        if hasattr(config, "text_config") and config.text_config is not None:
+            tc = config.text_config
+            for attr in ("vocab_size", "hidden_size", "pad_token_id"):
+                if getattr(config, attr, None) is None:
+                    setattr(config, attr, getattr(tc, attr, None))

        instance = cls(config)

        if embed_key is None:
-            embed_key = "model.embed_tokens.weight"
+            if hasattr(config, "text_config") and config.text_config is not None:
+                embed_key = "model.language_model.embed_tokens.weight"
+            else:
+                embed_key = "model.embed_tokens.weight"
        if lm_head_key is None:
            lm_head_key = "lm_head.weight"
diff --git a/tests/test_data/test_parsers.py b/tests/test_data/test_parsers.py
index 63b673dc..955b38ca 100644
--- a/tests/test_data/test_parsers.py
+++ b/tests/test_data/test_parsers.py
@@ -328,11 +328,7 @@ def test_qwen3_instruct_with_tools(self):
        )

    def test_qwen35_instruct(self):
-        self._run_template_test(
-            "Qwen/Qwen3.5-35B-A3B",
-            "qwen3.5",
-            messages=self.complete_reasoning_tool_conversation,
-        )
+        self._run_template_test("Qwen/Qwen3.5-35B-A3B", "qwen3.5")


if __name__ == "__main__":
diff --git a/tests/test_modeling/test_target/test_sglang_backend/test_sglang_backend.py b/tests/test_modeling/test_target/test_sglang_backend/test_sglang_backend.py
index da195343..bc5e74a3 100644
--- a/tests/test_modeling/test_target/test_sglang_backend/test_sglang_backend.py
+++ b/tests/test_modeling/test_target/test_sglang_backend/test_sglang_backend.py
@@ -32,6 +32,14 @@ def test_dense(rank, world_size, port, tp_size):
        device="cuda",
        attention_backend="fa3",
        mem_fraction_static=0.4,
+        enable_nccl_nvls=True,
+        enable_symm_mem=False,
+        enable_torch_compile=True,
+        enable_dp_attention=False,
+        enable_dp_lm_head=False,
+        enable_piecewise_cuda_graph=True,
+        ep_size=1,
+        context_length=256,
    )
    sgl_target_model.set_aux_hidden_states_layers()
    sgl_out = sgl_target_model.generate_eagle3_data(
@@ -62,6 +70,14 @@ def test_moe(rank, world_size, port, tp_size):
        device="cuda",
        attention_backend="fa3",
        mem_fraction_static=0.4,
+        enable_torch_compile=True,
+        enable_nccl_nvls=True,
+        enable_symm_mem=False,
+        enable_dp_attention=False,
+        enable_dp_lm_head=False,
+        enable_piecewise_cuda_graph=True,
+        ep_size=2,
+        context_length=256,
    )
    sgl_target_model.set_aux_hidden_states_layers()
    sgl_out = sgl_target_model.generate_eagle3_data(
@@ -193,6 +209,13 @@ def test_vlm(rank, world_size, port, tp_size):
        device="cuda",
        attention_backend="fa3",
        mem_fraction_static=0.75,
+        enable_torch_compile=True,
+        enable_nccl_nvls=True,
+        enable_symm_mem=False,  # Disable to avoid nccl_allocator compilation issues
+        enable_dp_attention=False,
+        enable_dp_lm_head=False,
+        enable_piecewise_cuda_graph=True,
+        context_length=4096,
    )
    sgl_target_model.set_aux_hidden_states_layers()
    sgl_out = sgl_target_model.generate_eagle3_data(
@@ -346,6 +369,13 @@ def test_vlm_multi_batch(rank, world_size, port, tp_size):
        device="cuda",
        attention_backend="fa3",
        mem_fraction_static=0.4,
+        enable_nccl_nvls=True,
+        enable_torch_compile=True,
+        enable_symm_mem=False,
+        enable_dp_attention=False,
+        enable_dp_lm_head=False,
+        enable_piecewise_cuda_graph=True,
+        context_length=4096,
    )
    sgl_target_model.set_aux_hidden_states_layers()
    sgl_out = sgl_target_model.generate_eagle3_data(