108 changes: 89 additions & 19 deletions src/liger_kernel/transformers/monkey_patch.py
@@ -1650,7 +1650,9 @@ def apply_liger_kernel_to_qwen3_vl(
cross_entropy: bool = False,
fused_linear_cross_entropy: bool = True,
rms_norm: bool = True,
swiglu: bool = False,
layer_norm: bool = True,
swiglu: bool = True,
geglu: bool = False,
model: PreTrainedModel = None,
) -> None:
"""
@@ -1663,7 +1665,9 @@ def apply_liger_kernel_to_qwen3_vl(
`cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
If `fused_linear_cross_entropy` is True, the logits will not be materialized, which is more memory efficient.
rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
swiglu (bool): Whether to apply Liger's SwiGLU MLP. Default is False.
layer_norm (bool): Whether to apply Liger's LayerNorm. Default is True.
swiglu (bool): Whether to apply Liger's SwiGLU MLP for text model. Default is True.
geglu (bool): Whether to apply Liger's GeGLU MLP for vision model. Default is False.
model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
loaded. Default is None.
"""
@@ -1676,16 +1680,20 @@ def apply_liger_kernel_to_qwen3_vl(
from transformers.models.qwen3_vl.modeling_qwen3_vl import Qwen3VLForConditionalGeneration
from transformers.models.qwen3_vl.modeling_qwen3_vl import Qwen3VLModel
from transformers.models.qwen3_vl.modeling_qwen3_vl import Qwen3VLTextModel
from transformers.models.qwen3_vl.modeling_qwen3_vl import Qwen3VLVisionModel

from liger_kernel.transformers.model.qwen3_vl import lce_forward as qwen3_vl_lce_forward

if rope:
modeling_qwen3_vl.apply_rotary_pos_emb = liger_rotary_pos_emb_with_cast
modeling_qwen3_vl.apply_rotary_pos_emb_vision = liger_rotary_pos_emb_with_cast_and_leading_batch

if rms_norm:
if rms_norm and model is None:
modeling_qwen3_vl.Qwen3VLTextRMSNorm = LigerRMSNorm

if layer_norm and model is None:
modeling_qwen3_vl.nn.LayerNorm = LigerLayerNorm

if cross_entropy:
from transformers.loss.loss_utils import nn

@@ -1697,37 +1705,65 @@ def apply_liger_kernel_to_qwen3_vl(
else:
modeling_qwen3_vl.Qwen3VLForConditionalGeneration.forward = qwen3_vl_lce_forward

if model is not None and rms_norm:
if swiglu:
modeling_qwen3_vl.Qwen3VLTextMLP = LigerSwiGLUMLP

if geglu:
logger.warning(
"geglu is set to True, noting that there might be numerical differences compared to the original model. "
"Check https://github.com/linkedin/Liger-Kernel/issues/959"
)
modeling_qwen3_vl.Qwen3VLVisionMLP = LigerGEGLUMLP

if model is not None:
if isinstance(model, (Qwen3VLForConditionalGeneration, Qwen3VLModel)):
text_model: Qwen3VLTextModel = model.language_model
vision_model: Qwen3VLVisionModel = model.visual
elif isinstance(model, Qwen3VLTextModel):
text_model = model
text_model: Qwen3VLTextModel = model
vision_model = None
else:
raise TypeError(
f"Unsupported Qwen3VL model type. `model` must be `Qwen3VLForConditionalGeneration`, `Qwen3VLModel` or `Qwen3VLTextModel`. Got: {type(model)}"
)

_patch_qwen3_vl_rms_norm = partial(_patch_rms_norm_module, offset=0.0, casting_mode="llama")
_patch_qwen3_vl_qk_norm = partial(_patch_rms_norm_module, offset=0.0, casting_mode="llama", row_mode=True)

if text_model is not None:
_patch_qwen3_vl_rms_norm(text_model.norm)
if rms_norm:
_patch_qwen3_vl_rms_norm(text_model.norm)
for decoder_layer in text_model.layers:
_patch_qwen3_vl_rms_norm(decoder_layer.input_layernorm)
_patch_qwen3_vl_rms_norm(decoder_layer.post_attention_layernorm)
self_attn = getattr(decoder_layer, "self_attn", None)
if self_attn is not None:
if hasattr(self_attn, "q_norm") and self_attn.q_norm is not None:
_patch_qwen3_vl_rms_norm(self_attn.q_norm)
if hasattr(self_attn, "k_norm") and self_attn.k_norm is not None:
_patch_qwen3_vl_rms_norm(self_attn.k_norm)
if rms_norm:
_patch_qwen3_vl_rms_norm(decoder_layer.input_layernorm)
_patch_qwen3_vl_rms_norm(decoder_layer.post_attention_layernorm)
self_attn = getattr(decoder_layer, "self_attn", None)
if self_attn is not None:
if hasattr(self_attn, "q_norm") and self_attn.q_norm is not None:
_patch_qwen3_vl_qk_norm(self_attn.q_norm)
if hasattr(self_attn, "k_norm") and self_attn.k_norm is not None:
_patch_qwen3_vl_qk_norm(self_attn.k_norm)

if swiglu:
_patch_swiglu_module(decoder_layer.mlp, LigerSwiGLUMLP)

if vision_model is not None:
for vision_block in vision_model.blocks:
if layer_norm:
_patch_layer_norm_module(vision_block.norm1)
_patch_layer_norm_module(vision_block.norm2)
if geglu:
_patch_geglu_module(vision_block.mlp)


def apply_liger_kernel_to_qwen3_vl_moe(
rope: bool = True,
cross_entropy: bool = False,
fused_linear_cross_entropy: bool = True,
rms_norm: bool = True,
swiglu: bool = False,
layer_norm: bool = True,
swiglu: bool = True,
geglu: bool = False,
model: PreTrainedModel = None,
) -> None:
"""
@@ -1738,7 +1774,9 @@ def apply_liger_kernel_to_qwen3_vl_moe(
fused_linear_cross_entropy (bool):
Whether to apply Liger's fused linear cross entropy loss. Default is True.
rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
swiglu (bool): Whether to apply Liger's SwiGLU MLP. Default is False.
layer_norm (bool): Whether to apply Liger's LayerNorm. Default is True.
swiglu (bool): Whether to apply Liger's SwiGLU MLP for text model. Default is True.
geglu (bool): Whether to apply Liger's GeGLU MLP for vision model. Default is False.
model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
loaded. Default is None.
"""
@@ -1750,7 +1788,10 @@ def apply_liger_kernel_to_qwen3_vl_moe(
from transformers.models.qwen3_vl_moe import modeling_qwen3_vl_moe
from transformers.models.qwen3_vl_moe.modeling_qwen3_vl_moe import Qwen3VLMoeForConditionalGeneration
from transformers.models.qwen3_vl_moe.modeling_qwen3_vl_moe import Qwen3VLMoeModel
from transformers.models.qwen3_vl_moe.modeling_qwen3_vl_moe import Qwen3VLMoeTextMLP
from transformers.models.qwen3_vl_moe.modeling_qwen3_vl_moe import Qwen3VLMoeTextModel
from transformers.models.qwen3_vl_moe.modeling_qwen3_vl_moe import Qwen3VLMoeTextSparseMoeBlock
from transformers.models.qwen3_vl_moe.modeling_qwen3_vl_moe import Qwen3VLMoeVisionModel

from liger_kernel.transformers.model.qwen3_vl_moe import lce_forward as qwen3_vl_moe_lce_forward

@@ -1772,29 +1813,58 @@ def apply_liger_kernel_to_qwen3_vl_moe(
else:
modeling_qwen3_vl_moe.Qwen3VLMoeForConditionalGeneration.forward = qwen3_vl_moe_lce_forward

if swiglu:
modeling_qwen3_vl_moe.Qwen3VLMoeTextMLP = LigerSwiGLUMLP
if geglu:
logger.warning(
"geglu is set to True, there might be numerical differences compared to the original model. "
"Check https://github.com/linkedin/Liger-Kernel/issues/959"
)
modeling_qwen3_vl_moe.Qwen3VLMoeVisionMLP = LigerGEGLUMLP

if model is not None and rms_norm:
if isinstance(model, (Qwen3VLMoeForConditionalGeneration, Qwen3VLMoeModel)):
text_model: Qwen3VLMoeTextModel = model.language_model
vision_model: Qwen3VLMoeVisionModel = model.visual
elif isinstance(model, Qwen3VLMoeTextModel):
text_model = model
vision_model = None
else:
raise TypeError(
f"Unsupported Qwen3VLMoe model type. `model` must be `Qwen3VLMoeForConditionalGeneration`, `Qwen3VLMoeModel` or `Qwen3VLMoeTextModel`. Got: {type(model)}"
)

_patch_qwen3_vl_moe_rms_norm = partial(_patch_rms_norm_module, offset=0.0, casting_mode="llama")
_patch_qwen3_vl_moe_qk_norm = partial(_patch_rms_norm_module, offset=0.0, casting_mode="llama", row_mode=True)

if text_model is not None:
_patch_qwen3_vl_moe_rms_norm(text_model.norm)
if rms_norm:
_patch_qwen3_vl_moe_rms_norm(text_model.norm)
for decoder_layer in text_model.layers:
_patch_qwen3_vl_moe_rms_norm(decoder_layer.input_layernorm)
_patch_qwen3_vl_moe_rms_norm(decoder_layer.post_attention_layernorm)
self_attn = getattr(decoder_layer, "self_attn", None)
if self_attn is not None:
if hasattr(self_attn, "q_norm") and self_attn.q_norm is not None:
_patch_qwen3_vl_moe_rms_norm(self_attn.q_norm)
_patch_qwen3_vl_moe_qk_norm(self_attn.q_norm)
if hasattr(self_attn, "k_norm") and self_attn.k_norm is not None:
_patch_qwen3_vl_moe_rms_norm(self_attn.k_norm)
_patch_qwen3_vl_moe_qk_norm(self_attn.k_norm)
if swiglu:
if isinstance(decoder_layer.mlp, Qwen3VLMoeTextSparseMoeBlock):
# TODO(xxx): Implement LigerMoe for the MoE sparse block for transformers v5
logger.warning(
"Skipping MLP patching for Qwen3VLMoeTextSparseMoeBlock. There will be a breaking change in transformers v5"
)
elif isinstance(decoder_layer.mlp, Qwen3VLMoeTextMLP):
_patch_swiglu_module(decoder_layer.mlp, LigerSwiGLUMLP)

if vision_model is not None:
for vision_block in vision_model.blocks:
if layer_norm:
_patch_layer_norm_module(vision_block.norm1)
_patch_layer_norm_module(vision_block.norm2)
if geglu:
_patch_geglu_module(vision_block.mlp)


def apply_liger_kernel_to_phi3(
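One behavior worth noting from the hunks above: class-level swaps such as `modeling_qwen3_vl.nn.LayerNorm = LigerLayerNorm` only affect models constructed afterwards, which is why they are now gated on `model is None`; already-loaded models are patched module by module instead. A minimal sketch of that instance-level idea (the helper is illustrative, not the repo's `_patch_layer_norm_module`):

import torch.nn as nn

def patch_layer_norm_inplace(parent: nn.Module, attr: str, new_cls) -> None:
    # Illustrative: swap an instantiated nn.LayerNorm for an equivalent
    # module, carrying over the trained weight and bias.
    old = getattr(parent, attr)
    new = new_cls(old.normalized_shape, eps=old.eps)
    new.load_state_dict(old.state_dict())
    setattr(parent, attr, new)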
2 changes: 1 addition & 1 deletion test/convergence/bf16/test_mini_models.py
@@ -1459,7 +1459,7 @@ def run_mini_model(
if "glm4" in model_name or "qwen3_next" in model_name:
kwargs["rope"] = False

model_supports_layer_norm = "qwen2_vl" in model_name
model_supports_layer_norm = "qwen2_vl" in model_name or "qwen3_vl" in model_name
if model_supports_layer_norm:
kwargs["layer_norm"] = True

2 changes: 1 addition & 1 deletion test/convergence/bf16/test_mini_models_multimodal.py
@@ -1150,7 +1150,7 @@ def run_mini_model_multimodal(
"cross_entropy": False,
}

if "qwen2_5_vl" not in model_name and "llava" not in model_name and "qwen3_vl" not in model_name:
if "qwen2_5_vl" not in model_name and "llava" not in model_name:
kwargs["layer_norm"] = True

if "gemma" in model_name:
2 changes: 1 addition & 1 deletion test/convergence/bf16/test_mini_models_with_logits.py
@@ -1454,7 +1454,7 @@ def run_mini_model(
if "glm4" in model_name or "llama4" in model_name or "qwen3_next" in model_name:
kwargs["rope"] = False

model_supports_layer_norm = "qwen2_vl" in model_name
model_supports_layer_norm = "qwen2_vl" in model_name or "qwen3_vl" in model_name
if model_supports_layer_norm:
kwargs["layer_norm"] = True

2 changes: 1 addition & 1 deletion test/convergence/fp32/test_mini_models.py
@@ -1446,7 +1446,7 @@ def run_mini_model(
if "glm4" in model_name or "qwen3_next" in model_name:
kwargs["rope"] = False

model_supports_layer_norm = "qwen2_vl" in model_name
model_supports_layer_norm = "qwen2_vl" in model_name or "qwen3_vl" in model_name
if model_supports_layer_norm:
kwargs["layer_norm"] = True

2 changes: 1 addition & 1 deletion test/convergence/fp32/test_mini_models_multimodal.py
@@ -1288,7 +1288,7 @@ def run_mini_model_multimodal(
}
if "llama4" in model_name:
kwargs["rope"] = False
if "qwen2_5_vl" not in model_name and "llava" not in model_name and "qwen3_vl" not in model_name:
if "qwen2_5_vl" not in model_name and "llava" not in model_name:
kwargs["layer_norm"] = True

if "gemma" in model_name:
2 changes: 1 addition & 1 deletion test/convergence/fp32/test_mini_models_with_logits.py
@@ -1466,7 +1466,7 @@ def run_mini_model(
if "glm4" in model_name or "llama4" in model_name:
kwargs["rope"] = False

model_supports_layer_norm = "qwen2_vl" in model_name
model_supports_layer_norm = "qwen2_vl" in model_name or "qwen3_vl" in model_name
if model_supports_layer_norm:
kwargs["layer_norm"] = True
