Skip to content

Commit e93dc4e

Browse files
committed
dev
Signed-off-by: h-guo18 <67671475+h-guo18@users.noreply.github.com>
1 parent 54f99f7 commit e93dc4e

3 files changed

Lines changed: 49 additions & 22 deletions

File tree

examples/speculative_decoding/main.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -170,7 +170,7 @@ def train():
170170
# To avoid OOM for large models, we load and convert model on CPU first.
171171
# Model will be moved to GPU during HF trainer.init().
172172
offline_kwargs = {"num_hidden_layers": 0} if use_offline_training else {}
173-
model = transformers.Qwen3VLMoeForConditionalGeneration.from_pretrained(
173+
model = transformers.Qwen3VLForConditionalGeneration.from_pretrained(
174174
model_args.model_name_or_path,
175175
torch_dtype="auto",
176176
device_map="cpu",

modelopt/torch/speculative/plugins/transformers.py

Lines changed: 35 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -425,16 +425,23 @@ def _base_model_lm_head(self):
425425
@property
426426
def _base_llm_config(self):
427427
"""Return the llm config for the base model, from LLM or VLM."""
428-
return self.config.llm_config if hasattr(self.config, "llm_config") else self.config
428+
# return self.config.llm_config if hasattr(self.config, "llm_config") else self.config
429+
return self.config.text_config
429430

430431
def _find_base_model_parts(self):
431432
"""Find model parts from different models and set base_{part}_path attributes."""
432433
base_model_parts_mapping = {
433-
"base_model_path": ["model", "backbone", "language_model.backbone"],
434+
"base_model_path": [
435+
"model.language_model",
436+
"model",
437+
"backbone",
438+
"language_model.backbone",
439+
],
434440
"base_model_embeddings_path": [
435441
"model.embed_tokens",
436442
"backbone.embeddings",
437443
"language_model.backbone.embeddings",
444+
"model.language_model.embed_tokens",
438445
],
439446
"base_model_lm_head_path": ["lm_head", "language_model.lm_head"],
440447
}
@@ -747,7 +754,8 @@ def _llm_or_vlm_embedding(self, input_ids, kwargs):
747754
del vit_embeds
748755
return tok_embeds.reshape(bs, seq_len, hid_size)
749756
else:
750-
raise ValueError(f"VLM model type {self.config.model_type} not supported")
757+
breakpoint()
758+
# raise ValueError(f"VLM model type {self.config.model_type} not supported")
751759

752760
def _base_model_forward(
753761
self,
@@ -769,6 +777,7 @@ def _base_model_forward(
769777
**kwargs,
770778
)
771779
past_key_values = getattr(outputs, "past_key_values", None)
780+
input_embeds = outputs.hidden_states[0]
772781
base_model_hidden_states = outputs.hidden_states[-1]
773782
base_model_logits = outputs.logits
774783

@@ -780,7 +789,13 @@ def _base_model_forward(
780789
labels = labels.view(-1)
781790
base_model_loss = loss_fct(loss_logits, labels)
782791

783-
return base_model_hidden_states, base_model_logits, base_model_loss, past_key_values
792+
return (
793+
input_embeds,
794+
base_model_hidden_states,
795+
base_model_logits,
796+
base_model_loss,
797+
past_key_values,
798+
)
784799

785800
def _map_logits_to_draft_vocab(self, full_logits):
786801
reverse_mapping = (
@@ -872,16 +887,20 @@ def forward(
872887
base_model_logits = self.lm_head(base_model_hidden_states)
873888
base_model_loss, past_key_values = None, None
874889
else:
875-
base_model_hidden_states, base_model_logits, base_model_loss, past_key_values = (
876-
self._base_model_forward(
877-
input_ids,
878-
attention_mask,
879-
position_ids,
880-
past_key_values,
881-
self.eagle_freeze_base_model,
882-
labels,
883-
**kwargs,
884-
)
890+
(
891+
base_input_embeds,
892+
base_model_hidden_states,
893+
base_model_logits,
894+
base_model_loss,
895+
past_key_values,
896+
) = self._base_model_forward(
897+
input_ids,
898+
attention_mask,
899+
position_ids,
900+
past_key_values,
901+
self.eagle_freeze_base_model,
902+
labels,
903+
**kwargs,
885904
)
886905

887906
if not isinstance(past_key_values, Cache):
@@ -912,7 +931,8 @@ def forward(
912931
eagle_cache,
913932
)
914933
with torch.no_grad():
915-
inputs_embeds = self._llm_or_vlm_embedding(eagle_input_ids, kwargs)
934+
# inputs_embeds = self._llm_or_vlm_embedding(eagle_input_ids, kwargs)
935+
inputs_embeds = base_input_embeds.roll(-1, 1)
916936

917937
past_key_values.eagle_cache = eagle_cache
918938

modelopt/torch/utils/plugins/transformers_dataset.py

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717

1818
import copy
1919
import itertools
20+
import os
2021

2122
import torch
2223
import transformers
@@ -44,8 +45,8 @@ def _sharegpt_to_openai_messages(conversations: list[dict]):
4445
}
4546
messages = []
4647
for msg in conversations:
47-
role = role_mapping[msg["from"]]
48-
content = msg["value"]
48+
role = role_mapping[msg["role"]]
49+
content = msg["content"]
4950
messages.append({"role": role, "content": content})
5051
return messages
5152

@@ -225,7 +226,7 @@ def __init__(
225226
chat_template: str | None = None,
226227
add_generation_prompt: bool = False,
227228
answer_only_loss: bool = False,
228-
local_image_path: str | None = None,
229+
local_image_path: str = "",
229230
return_labels: bool = False,
230231
):
231232
"""Initialize the VisionLanguageDataset."""
@@ -242,8 +243,6 @@ def __init__(
242243
)
243244

244245
def _process_multimodal_sample(self, examples):
245-
print(examples)
246-
breakpoint()
247246
tokenized_messages = self.processor.apply_chat_template(
248247
examples,
249248
tokenize=True,
@@ -279,9 +278,17 @@ def __call__(self, examples):
279278
for msg in copy_messages:
280279
if isinstance(msg["content"], str):
281280
msg["content"] = [{"type": "text", "text": msg["content"]}]
281+
282282
for ctn in msg["content"]:
283283
if ctn["type"] == "image" and "image" in ctn:
284-
ctn["image"] = self.local_image_path + "/" + ctn["image"]
284+
ctn["image"] = os.path.abspath(
285+
os.path.join(self.local_image_path, ctn["image"])
286+
)
287+
# If any value in ctn is None, delete that key
288+
# HF dataloader add Nones to align keys. Leads to error in processor.
289+
keys_to_delete = [k for k, v in ctn.items() if v is None]
290+
for k in keys_to_delete:
291+
del ctn[k]
285292

286293
batch.append(copy_messages)
287294

0 commit comments

Comments
 (0)