Commit c17f6c2

yueshen2016 authored and danielkorzekwa committed
Support megatron generate for vlm (#773)
## What does this PR do?

**Type of change:** New feature

**Overview:** This PR adds VLM (Vision Language Model) generation support to `megatron_generate`.

## Usage

```python
# Add a code snippet demonstrating how to use this
```

## Testing

## Before your PR is "*Ready for review*"

- **Make sure you read and follow [Contributor guidelines](https://github.com/NVIDIA/Model-Optimizer/blob/main/CONTRIBUTING.md)** and your commits are signed.
- **Is this change backward compatible?**: Yes/No
- **Did you write any new necessary tests?**: Yes/No
- **Did you add or update any necessary documentation?**: Yes/No
- **Did you update [Changelog](https://github.com/NVIDIA/Model-Optimizer/blob/main/CHANGELOG.rst)?**: Yes/No

## Additional Information

## Summary by CodeRabbit

* **New Features**
  * Added Vision Language Model support to the text generation pipeline, enabling simultaneous processing of image and text inputs during both generation and prefill operations.
* **Improvements**
  * Enhanced data flow to properly route multimodal inputs (images and text tokens) through the generation paths, with automatic detection and handling of vision-enabled model architectures.

Signed-off-by: James Shen <yueshen@nvidia.com>
1 parent dedd0a0 commit c17f6c2

1 file changed: modelopt/torch/utils/plugins/megatron_generate.py (52 additions, 8 deletions)
```diff
@@ -210,13 +210,48 @@ def _forward_step_func(data, model):
         # NOTE: we don't support traditional positional embedding. Only RoPE or YaRN are supported.
         position_ids = None

-        output_tensor = model(
-            data["tokens"],
-            position_ids,
-            attention_mask,
-            inference_context=inference_context,
-            runtime_gather_output=True,
-        )
+        # Check if this is a VLM model (has vision inputs)
+        _has_pixel_values = data.get("pixel_values") is not None
+        _has_image_grid_thw = data.get("image_grid_thw") is not None
+        _has_image_sizes = data.get("image_sizes") is not None
+        has_vision_inputs = _has_pixel_values or _has_image_grid_thw or _has_image_sizes
+
+        if has_vision_inputs:
+            # For VLM models:
+            # - position_ids: [batch, seq_len] (required for RoPE with multi-modal positions)
+            # - attention_mask: [batch, seq_len] (simple 1D boolean mask, not 4D causal)
+            vlm_position_ids = (
+                torch.arange(seq_len, dtype=torch.long, device=device)
+                .unsqueeze(0)
+                .expand(batch_size, -1)
+            )
+            vlm_attention_mask = torch.ones((batch_size, seq_len), dtype=torch.bool, device=device)
+
+            forward_args = {
+                "input_ids": data["tokens"],
+                "position_ids": vlm_position_ids,
+                "attention_mask": vlm_attention_mask,
+                "inference_context": inference_context,
+                "runtime_gather_output": True,
+            }
+            # Add vision inputs
+            if _has_pixel_values:
+                forward_args["pixel_values"] = data["pixel_values"]
+            if _has_image_grid_thw:
+                forward_args["image_grid_thw"] = data["image_grid_thw"]
+            if _has_image_sizes:
+                forward_args["image_sizes"] = data["image_sizes"]
+
+            output_tensor = model(**forward_args)
+        else:
+            # For text-only LLM models
+            output_tensor = model(
+                data["tokens"],
+                position_ids,
+                attention_mask,
+                inference_context=inference_context,
+                runtime_gather_output=True,
+            )
         return output_tensor, _dummy_loss_func

     disable_tqdm = disable_tqdm or torch.distributed.get_rank() > 0
```
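The branching in this hunk can be sketched in plain Python. This is a minimal sketch, not the patch itself: `build_forward_args` and `make_vlm_position_ids` are hypothetical helper names, plain lists stand in for torch tensors, and `inference_context`/`runtime_gather_output` are omitted for brevity.

```python
# Sketch of the VLM-vs-text routing in _forward_step_func (hypothetical helpers;
# the real code builds torch tensors and calls a Megatron model).

def make_vlm_position_ids(batch_size, seq_len):
    # Mirrors torch.arange(seq_len).unsqueeze(0).expand(batch_size, -1)
    return [list(range(seq_len)) for _ in range(batch_size)]

def build_forward_args(data, batch_size, seq_len):
    # Presence of any vision input selects the VLM path
    has_vision_inputs = any(
        data.get(k) is not None
        for k in ("pixel_values", "image_grid_thw", "image_sizes")
    )
    if not has_vision_inputs:
        # Text-only path: the real code passes tokens/position_ids/attention_mask positionally
        return {"input_ids": data["tokens"]}
    forward_args = {
        "input_ids": data["tokens"],
        "position_ids": make_vlm_position_ids(batch_size, seq_len),
        # Simple per-token boolean mask, not a 4D causal mask
        "attention_mask": [[True] * seq_len for _ in range(batch_size)],
    }
    # Forward only the vision inputs that are actually present
    for key in ("pixel_values", "image_grid_thw", "image_sizes"):
        if data.get(key) is not None:
            forward_args[key] = data[key]
    return forward_args
```

The design choice worth noting is that the text-only path is left untouched; VLM handling is purely additive and triggered by the presence of vision keys in the data dict.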
```diff
@@ -250,9 +285,18 @@ def _forward_step_func(data, model):
     else:
         tokens = input_ids

+    data_dict = {"tokens": tokens}
+    # Vision inputs should only be passed during prefill (step 0), not during decode steps
+    if pixel_values is not None:
+        data_dict["pixel_values"] = pixel_values
+    if image_grid_thw is not None:
+        data_dict["image_grid_thw"] = image_grid_thw
+    if image_sizes is not None:
+        data_dict["image_sizes"] = image_sizes
+
     list_of_logits = get_forward_backward_func()(
         forward_step_func=_forward_step_func,
-        data_iterator=[{"tokens": tokens}],
+        data_iterator=[data_dict],
         model=model,
         num_microbatches=1,
         seq_length=tokens.shape[-1],
```
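The second hunk's data-dict assembly can be sketched as a standalone function (a minimal sketch; `build_data_dict` is a hypothetical name, not in the patch). During prefill, vision inputs travel alongside the tokens; on later decode steps only the new tokens are supplied, so the dict naturally contains just `"tokens"`.

```python
# Sketch of the prefill-only data dict fed to the forward-backward data_iterator.
# Vision inputs accompany the tokens only when they are provided (step 0 / prefill);
# decode steps call this with tokens alone.

def build_data_dict(tokens, pixel_values=None, image_grid_thw=None, image_sizes=None):
    data_dict = {"tokens": tokens}
    if pixel_values is not None:
        data_dict["pixel_values"] = pixel_values
    if image_grid_thw is not None:
        data_dict["image_grid_thw"] = image_grid_thw
    if image_sizes is not None:
        data_dict["image_sizes"] = image_sizes
    return data_dict
```

Because decode steps omit the vision keys, the `has_vision_inputs` check in `_forward_step_func` falls through to the text-only path after prefill, which is exactly the intended behavior.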
