feat: modify kwai template#9117
Conversation
There was a problem hiding this comment.
Code Review
This pull request refactors the Kwai template to support slow-fast multimodal inputs and modularizes visual embedding logic. It also integrates RoPE index calculation into the post-encoding phase. Review feedback indicates that the early return in _post_encode may incorrectly bypass necessary embedding and position ID logic during inference. A suggestion was also made to use a more robust method for calculating dimension products to avoid floating-point conversions.
| if not self.is_training: | ||
| return inputs |
There was a problem hiding this comment.
The early return if not self.is_training prevents the computation of inputs_embeds and the new position_ids (via get_rope_index_slowfast) during inference when using the transformers backend. For multimodal models like KeyeVL, replacing token embeddings with visual embeddings and using the correct RoPE indices is essential for correct inference results. Since this PR aims to fix the RoPE implementation, this logic should also be applied during inference.
| cu_seqlens = [0] | ||
| for idx, thw_tuple in enumerate(grid_hws): | ||
| numel = int(np.prod(thw_tuple)) | ||
| media_position_ids = torch.arange(numel, device=device) % int(np.prod(thw_tuple[1:])) |
There was a problem hiding this comment.
Using np.prod on a slice of a tuple returns a float in many NumPy versions. While it is cast to int here, using math.prod (available in Python 3.8+) is generally more idiomatic for calculating the product of dimensions in a shape tuple and avoids floating-point conversions.
| media_position_ids = torch.arange(numel, device=device) % int(np.prod(thw_tuple[1:])) | |
| media_position_ids = torch.arange(numel, device=device) % int(torch.prod(torch.tensor(thw_tuple[1:]))) |
|
thanks! please pass lint test |
PR type
PR information
Fix the
KeyeVLTemplateimplementation with the following changes: