Skip to content

Commit 0725bc8

Browse files
authored
fix bug in qwen3_vl eagle3 speculative decoding (Tencent#176)
1 parent a81a0e9 commit 0725bc8

2 files changed

Lines changed: 42 additions & 18 deletions

File tree

angelslim/compressor/speculative/train/models/target/target_model_wrapper.py

Lines changed: 40 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -280,15 +280,14 @@ class VLMTransformersBackend(BaseBackend):
280280
def load_model(self):
281281
from transformers import AutoModelForImageTextToText, AutoProcessor
282282

283-
default_kwargs = {
284-
"dtype": torch.bfloat16,
285-
"device_map": "auto",
286-
"trust_remote_code": True,
287-
}
288-
default_kwargs.update(self.kwargs)
283+
device = decide_device_for_distributed()
284+
print_with_rank(f"Loading model to device: {device}")
285+
286+
# Prepare model loading configuration
287+
model_kwargs = self._prepare_model_kwargs(device)
289288

290289
self.model = AutoModelForImageTextToText.from_pretrained(
291-
self.model_path, **default_kwargs
290+
self.model_path, **model_kwargs
292291
)
293292

294293
# Freeze the base model
@@ -300,6 +299,24 @@ def load_model(self):
300299
self.model_path, trust_remote_code=True
301300
)
302301

302+
def _prepare_model_kwargs(self, device: str) -> dict:
303+
"""
304+
Prepare keyword arguments for model loading.
305+
306+
Args:
307+
device: Target device for model placement
308+
309+
Returns:
310+
Dictionary of model loading arguments
311+
"""
312+
default_kwargs = {
313+
"dtype": torch.bfloat16,
314+
"device_map": device,
315+
"trust_remote_code": True,
316+
}
317+
default_kwargs.update(self.kwargs)
318+
return default_kwargs
319+
303320
def get_hidden_states_and_logits(
304321
self,
305322
input_ids: torch.Tensor,
@@ -317,6 +334,12 @@ def get_hidden_states_and_logits(
317334
Returns:
318335
Tuple of (concatenated_hidden_states, logits)
319336
"""
337+
pixel_values = None
338+
image_grid_thw = None
339+
if "pixel_values" in kwargs:
340+
pixel_values = kwargs["pixel_values"].squeeze(0)
341+
if "image_grid_thw" in kwargs:
342+
image_grid_thw = kwargs["image_grid_thw"].squeeze(0)
320343
inputs_embeds_list, position_ids_list = [], []
321344

322345
def hook(module, args, kwargs):
@@ -336,6 +359,8 @@ def hook(module, args, kwargs):
336359
outputs = self.model(
337360
input_ids,
338361
attention_mask=attention_mask,
362+
pixel_values=pixel_values,
363+
image_grid_thw=image_grid_thw,
339364
output_hidden_states=True,
340365
output_logits=True,
341366
)
@@ -375,6 +400,12 @@ def get_aux_and_target_hiddens(
375400
Returns:
376401
Tuple of (auxiliary_hidden_states, final_hidden_states)
377402
"""
403+
pixel_values = None
404+
image_grid_thw = None
405+
if "pixel_values" in kwargs:
406+
pixel_values = kwargs["pixel_values"].squeeze(0)
407+
if "image_grid_thw" in kwargs:
408+
image_grid_thw = kwargs["image_grid_thw"].squeeze(0)
378409
inputs_embeds_list, position_ids_list = [], []
379410

380411
def hook(module, args, kwargs):
@@ -393,6 +424,8 @@ def hook(module, args, kwargs):
393424
with torch.no_grad():
394425
outputs = self.model(
395426
input_ids,
427+
pixel_values=pixel_values,
428+
image_grid_thw=image_grid_thw,
396429
attention_mask=attention_mask,
397430
output_hidden_states=True,
398431
output_logits=True,

angelslim/compressor/speculative/train/trainer/online_vlm_eagle3_trainer.py

Lines changed: 2 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -150,13 +150,12 @@ def draft_model_training_time_test(
150150
# Iterative speculative decoding training loop
151151
for idx in range(self.length):
152152
# Get input embeddings with gradient tracking
153-
if inputs_embeds is None:
154-
inputs_embeds = self.draft_model.get_input_embeddings(input_ids)
153+
inputs_embeds = self.draft_model.get_input_embeddings(input_ids)
155154
if not inputs_embeds.requires_grad:
156155
inputs_embeds.requires_grad = True
157156

158157
# Encode through draft model layers
159-
hidden_states = self.draft_model.encode_layers(
158+
hidden_states, cache_hidden = self.draft_model.encode_layers(
160159
inputs_embeds=inputs_embeds,
161160
hidden_states=hidden_states,
162161
cache_hidden=cache_hidden,
@@ -198,14 +197,6 @@ def draft_model_training_time_test(
198197
target_logits = padding(target_logits, left=False)
199198
loss_mask = padding(loss_mask, left=False)
200199

201-
# Update attention mask to prevent attending to future positions
202-
ind = torch.arange(seq_length, device=attention_mask.device)
203-
new_attention_mask = attention_mask.clone()
204-
new_attention_mask[:, :, ind[idx:], ind[: seq_length - idx]] = (
205-
torch.finfo(attention_mask.dtype).min
206-
)
207-
attention_mask = new_attention_mask
208-
209200
# Compute weighted loss
210201
ploss_weight = [0.8**i for i in range(len(plosses))]
211202
ploss = sum([ploss_weight[i] * plosses[i] for i in range(len(plosses))])

0 commit comments

Comments
 (0)