73 changes: 63 additions & 10 deletions scripts/train_dflash.py
@@ -20,7 +20,7 @@
from tqdm import tqdm
from transformers import AutoConfig, AutoTokenizer

from datasets import load_dataset
from datasets import load_dataset, load_from_disk
from specforge.args import SGLangBackendArgs, TrackerArgs
from specforge.core.dflash import OnlineDFlashModel
from specforge.data import build_eagle3_dataset, prepare_dp_dataloaders
@@ -35,6 +35,8 @@
from specforge.tracker import create_tracker
from specforge.utils import get_last_checkpoint, print_on_rank0, print_with_rank

logging.getLogger("sglang.srt.mem_cache.memory_pool").setLevel(logging.WARNING)


def parse_args():
parser = argparse.ArgumentParser(description="Train DFlash Draft Model")
@@ -94,6 +96,9 @@ def parse_args():
default=None,
help="LM head weight key in the target model. Default: 'lm_head.weight'.",
)
model_group.add_argument("--is-vlm", action="store_true")
model_group.add_argument("--min-pixels", type=int, default=50176)
model_group.add_argument("--max-pixels", type=int, default=802816)

dataset_group = parser.add_argument_group("dataset")
dataset_group.add_argument("--train-data-path", type=str, required=True)
@@ -207,7 +212,16 @@ def build_models(args) -> Tuple[DFlashTargetModel, DFlashDraftModel]:
return target_model, draft_model


def build_dataloader(args, tokenizer) -> Tuple[DataLoader, Optional[DataLoader]]:
def _load_raw_dataset(data_path: str):
"""load jsonl"""
if os.path.isdir(data_path):
return load_from_disk(data_path)
return load_dataset("json", data_files=data_path)["train"]


def build_dataloader(
args, tokenizer, processor=None
) -> Tuple[DataLoader, Optional[DataLoader]]:
"""Build train and eval dataloaders."""
import hashlib

@@ -219,7 +233,7 @@ def build_dataloader(args, tokenizer) -> Tuple[DataLoader, Optional[DataLoader]]
)
cache_key = hashlib.md5(cache_params_string.encode()).hexdigest()

train_dataset = load_dataset("json", data_files=args.train_data_path)["train"]
train_dataset = _load_raw_dataset(args.train_data_path)
train_eagle3_dataset = build_eagle3_dataset(
dataset=train_dataset,
tokenizer=tokenizer,
Expand All @@ -229,8 +243,9 @@ def build_dataloader(args, tokenizer) -> Tuple[DataLoader, Optional[DataLoader]]
cache_dir=os.path.join(args.cache_dir, "processed_dataset"),
cache_key=cache_key,
num_proc=args.build_dataset_num_proc,
is_vlm=args.is_vlm,
processor=processor,
)

min_loss_tokens = 2 * args.block_size
original_size = len(train_eagle3_dataset)
train_eagle3_dataset = train_eagle3_dataset.filter(
@@ -246,24 +261,28 @@ def build_dataloader(args, tokenizer) -> Tuple[DataLoader, Optional[DataLoader]]
num_workers=args.dataloader_num_workers,
shuffle=True,
process_group=get_dp_group(),
is_vlm=args.is_vlm,
)

eval_dataloader = None
if args.eval_data_path:
eval_dataset = load_dataset("json", data_files=args.eval_data_path)["train"]
eval_dataset = _load_raw_dataset(args.eval_data_path)
eval_eagle3_dataset = build_eagle3_dataset(
dataset=eval_dataset,
tokenizer=tokenizer,
chat_template=args.chat_template,
max_length=args.max_length,
is_preformatted=args.is_preformatted,
is_vlm=args.is_vlm,
processor=processor,
)
eval_dataloader = prepare_dp_dataloaders(
eval_eagle3_dataset,
args.batch_size,
num_workers=args.dataloader_num_workers,
shuffle=False,
process_group=get_dp_group(),
is_vlm=args.is_vlm,
)

return train_dataloader, eval_dataloader
@@ -396,7 +415,23 @@ def main():
f"step {resume_state['global_step']}"
)

tokenizer = AutoTokenizer.from_pretrained(args.target_model_path)
tokenizer = AutoTokenizer.from_pretrained(
args.target_model_path, trust_remote_code=args.trust_remote_code
)

processor = None
if args.is_vlm:
from transformers import AutoProcessor

processor = AutoProcessor.from_pretrained(
args.target_model_path,
min_pixels=args.min_pixels,
max_pixels=args.max_pixels,
trust_remote_code=args.trust_remote_code,
)
Reviewer comment on lines +426 to +431 (Contributor, severity: medium):

The exist_ok=True argument passed to AutoProcessor.from_pretrained does not appear to be a standard argument for this Hugging Face Transformers method. While it might be ignored if trust_remote_code=True allows custom arguments in the model's loading code, that is not guaranteed and could lead to unexpected behavior or errors with different models or library versions. It would be safer to remove this argument if it is not strictly required by the Qwen-VL model's custom code.
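For reference, a minimal sketch of the call the reviewer is suggesting, with only documented arguments; the model id is a placeholder and the pixel bounds simply reuse this PR's defaults:

from transformers import AutoProcessor

# Documented kwargs only; the non-standard exist_ok argument is dropped.
processor = AutoProcessor.from_pretrained(
    "org/placeholder-vl-model",  # hypothetical model id
    min_pixels=50176,   # 224 * 224
    max_pixels=802816,  # 896 * 896
    trust_remote_code=True,
)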
print_on_rank0(
f"Loaded VLM processor (min_pixels={args.min_pixels}, max_pixels={args.max_pixels})"
)

if args.mask_token_id is not None:
mask_token_id = args.mask_token_id
Expand All @@ -412,8 +447,7 @@ def main():
draft_model.config.dflash_config["target_layer_ids"] = draft_model.target_layer_ids
print_on_rank0(f"dflash_config: {draft_model.config.dflash_config}")

train_dataloader, eval_dataloader = build_dataloader(args, tokenizer)

train_dataloader, eval_dataloader = build_dataloader(args, tokenizer, processor)
steps_per_epoch = math.ceil(len(train_dataloader) / args.accumulation_steps)
total_steps = args.num_epochs * steps_per_epoch
print_on_rank0(f"Total training steps: {total_steps}")
Expand Down Expand Up @@ -496,11 +530,30 @@ def main():
continue
global_step += 1

input_ids_cpu = data["input_ids"]
attention_mask_cpu = data["attention_mask"]
loss_mask_cpu = data["loss_mask"]

input_ids = data["input_ids"].cuda()
attention_mask = data["attention_mask"].cuda()
loss_mask = data["loss_mask"].cuda()
pixel_values = None
image_grid_thw_cpu = None
if (
args.is_vlm
and "pixel_values" in data
and data["pixel_values"] is not None
):
pixel_values = data["pixel_values"].cuda()
image_grid_thw_cpu = [
thw.squeeze() if thw is not None else None
for thw in data["image_grid_thw"]
]
target_output = target_model.generate_dflash_data(
input_ids, attention_mask, loss_mask
input_ids_cpu,
attention_mask_cpu,
loss_mask_cpu,
pixel_values=pixel_values,
image_grid_thw=image_grid_thw_cpu,
)
hidden_states = target_output.hidden_states.cuda() # Ensure on GPU

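For context, a hypothetical launch of the updated script using the VLM flags added above; the process count, model id, data path, and chat template value are illustrative, and batch size 1 matches the VLM collator's requirement:

torchrun --nproc_per_node=8 scripts/train_dflash.py \
    --target-model-path org/placeholder-vl-model \
    --train-data-path ./data/train.jsonl \
    --is-vlm \
    --min-pixels 50176 \
    --max-pixels 802816 \
    --batch-size 1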
75 changes: 54 additions & 21 deletions specforge/data/preprocessing.py
@@ -202,8 +202,6 @@ def preprocess_vlm_conversations(
- pixel_values: List of pixel values for images in the examples.
- image_grid_thw: List of image grid tensors.
"""
system_prompt = chat_template.system_prompt

# prepare result
results = {
"input_ids": [],
@@ -213,36 +211,71 @@
"image_grid_thw": [],
}

# Note: currently, we assume that each example has only one image
for i, image in enumerate(examples["image"]):
for i, images in enumerate(examples["images"]):
source = examples["conversations"][i]
messages = [{"role": "system", "content": system_prompt}]
messages = []
Reviewer comment (Contributor):

There are a few unused variables — please do a cleanup pass (e.g. system_prompt, mrope_interleaved).

if not source:
# if the source is None, skip it
continue

if not images:
text_messages = []
convroles = ["user", "assistant"]
for j, sentence in enumerate(source):
role = sentence["role"]
assert role == convroles[j % 2], f"unexpected role {role}"
text_messages.append({"role": role, "content": sentence["content"]})
conversation = processor.apply_chat_template(
text_messages,
tokenize=False,
add_generation_prompt=False,
)
encoding = processor(
text=[conversation],
max_length=max_length,
truncation=True,
return_tensors="pt",
return_offsets_mapping=True,
add_special_tokens=False,
)

input_ids = encoding.input_ids[0]
offsets = encoding.offset_mapping[0]

# get conversation with image info for loss mask generation
decoded_conversation = processor.tokenizer.decode(
encoding.input_ids[0], skip_special_tokens=False
)

# Apply loss mask
loss_mask = _apply_loss_mask_from_chat_template(
decoded_conversation, offsets, chat_template
)
results["input_ids"].append(input_ids[None, :])
results["loss_mask"].append(loss_mask[None, :])
results["attention_mask"].append(torch.ones_like(loss_mask)[None, :])
results["pixel_values"].append(torch.empty(0, 0).float())
results["image_grid_thw"].append([])
continue

if source[0]["role"] != "user":
# if the first message is not from user, skip it
source = source[1:]

convroles = ["user", "assistant"]
has_added_images = False
for j, sentence in enumerate(source):
role = sentence["role"]
assert role == convroles[j % 2], f"unexpected role {role}"
if role == "user":
# if the message is from user and has image, process the image
messages.append(
{
"role": role,
"content": [
{
"type": "image",
"image": image,
},
{"type": "text", "text": sentence["content"]},
],
}
)
# Insert all images into the first user message
if not has_added_images:
content = [{"type": "image", "image": img} for img in images]
content.append({"type": "text", "text": sentence["content"]})
messages.append({"role": role, "content": content})
has_added_images = True
else:
messages.append({"role": role, "content": sentence["content"]})
else:
messages.append({"role": role, "content": sentence["content"]})

@@ -273,7 +306,7 @@ def preprocess_vlm_conversations(
input_ids = encoding.input_ids[0]
offsets = encoding.offset_mapping[0]
pixel_values = encoding.pixel_values
image_grid_thw = encoding.image_grid_thw[0]
image_grid_thw = encoding.image_grid_thw # shape: (num_images, 3)

# get conversation with image info for loss mask generation
decoded_conversation = processor.tokenizer.decode(
Expand All @@ -289,7 +322,7 @@ def preprocess_vlm_conversations(
results["loss_mask"].append(loss_mask[None, :])
results["attention_mask"].append(torch.ones_like(loss_mask)[None, :])
results["pixel_values"].append(pixel_values)
results["image_grid_thw"].append(image_grid_thw[None, :])
results["image_grid_thw"].append(image_grid_thw)
return results


@@ -390,7 +423,7 @@ def preprocess_function(examples):
# Parse tools: handle JSON strings from safe_conversations_generator
tools = []
for tool_item in tools_raw:
if isinstance(tool_item, (str, list)):
if isinstance(tool_item, str):
try:
tools.append(json.loads(tool_item))
except json.JSONDecodeError:
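To make the new input contract concrete, a minimal sketch of an examples batch that preprocess_vlm_conversations can consume after this change, covering one image sample and one text-only sample; the argument order of the call is assumed, not shown in this diff:

# image is any PIL.Image loaded elsewhere; all values are illustrative.
examples = {
    "images": [[image], []],  # per-sample image lists; [] means text-only
    "conversations": [
        [
            {"role": "user", "content": "Describe the image."},
            {"role": "assistant", "content": "A cat on a sofa."},
        ],
        [
            {"role": "user", "content": "Hello!"},
            {"role": "assistant", "content": "Hi there."},
        ],
    ],
}
results = preprocess_vlm_conversations(examples, processor, chat_template, max_length)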
48 changes: 29 additions & 19 deletions specforge/data/utils.py
@@ -205,6 +205,10 @@ def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]:
- attention_mask: torch.Tensor of shape (B, N)
- loss_mask: torch.Tensor of shape (B, N)
"""
assert len(features) == 1, (
f"VlmDataCollatorWithPadding requires batch_size=1, got {len(features)}. "
"Set per_device_train_batch_size=1 in your training config."
)
max_length = max(item["input_ids"].shape[1] for item in features)
batch_input_ids = torch.cat(
[self.paddingtensor2D(item["input_ids"], max_length) for item in features]
@@ -218,12 +222,30 @@ def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]:
batch_loss_mask = torch.cat(
[self.paddingtensor2D(item["loss_mask"], max_length) for item in features]
)
batch_pixel_values = torch.cat(
[item["pixel_values"] for item in features], dim=0
)
batch_image_grid_thw = torch.cat(
[item["image_grid_thw"] for item in features], dim=0
)
# Collect pixel_values and image_grid_thw per sample.
# Image samples have non-empty pixel_values; text-only samples have empty tensors.
all_pixel_values = []
all_image_grid_thw = []
for item in features:
pv = item.get("pixel_values")
thw = item.get("image_grid_thw")
if pv is not None and isinstance(pv, torch.Tensor) and pv.numel() > 0:
all_pixel_values.append(pv)
all_image_grid_thw.append(thw)
else:
all_image_grid_thw.append(None)

if all_pixel_values:
batch_pixel_values = torch.cat(all_pixel_values, dim=0)
else:
batch_pixel_values = None

# If all samples are text-only, set image_grid_thw to None
if all(thw is None for thw in all_image_grid_thw):
batch_image_grid_thw = None
else:
batch_image_grid_thw = all_image_grid_thw

batch = {
"input_ids": batch_input_ids,
"attention_mask": batch_attention_mask,
Expand Down Expand Up @@ -304,17 +326,10 @@ def prepare_dp_dataloaders(


def parse_harmony_message_content(content):
"""
解析 content 字符串中的 Harmony 格式。
如果匹配到 Harmony 格式,返回包含 channel 和 content 的列表;
否则,返回原内容并标记为默认 channel。
"""
# 匹配 <|channel|>xxx<|message|>yyy<|end|>
pattern = r"<\|channel\|>(.*?)<\|message\|>(.*?)<\|end|>"
matches = re.findall(pattern, content, re.DOTALL)

if not matches:
# No Harmony tags matched; treat the content as plain text
return [{"channel": "text", "content": content}]

results = []
@@ -324,22 +339,17 @@ def parse_harmony_message_content(content):


def process_harmony_conversations(conversation):
"""
处理传入的 list[list[dict]] 结构
"""
new_conversation = []
for msg in conversation:
role = msg.get("role")
original_content = msg.get("content", "")

# Parse the Harmony structure inside content
segments = parse_harmony_message_content(original_content)

# Create a new message dict for each parsed channel
for seg in segments:
new_msg = {
"role": role,
"channel": seg["channel"], # 新增字段标识通道
"channel": seg["channel"],
"content": seg["content"],
}
new_conversation.append(new_msg)
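A quick illustration of the two helpers above, assuming the corrected end-tag pattern; the channel name follows the Harmony format the regex targets, and the expected outputs follow the docstring:

content = "<|channel|>analysis<|message|>Let me think...<|end|>"
parse_harmony_message_content(content)
# -> [{"channel": "analysis", "content": "Let me think..."}]

parse_harmony_message_content("plain reply")
# -> [{"channel": "text", "content": "plain reply"}]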