diff --git a/src/liger_kernel/transformers/monkey_patch.py b/src/liger_kernel/transformers/monkey_patch.py
index df3f74951..cdc880024 100755
--- a/src/liger_kernel/transformers/monkey_patch.py
+++ b/src/liger_kernel/transformers/monkey_patch.py
@@ -1650,7 +1650,9 @@ def apply_liger_kernel_to_qwen3_vl(
     cross_entropy: bool = False,
     fused_linear_cross_entropy: bool = True,
     rms_norm: bool = True,
-    swiglu: bool = False,
+    layer_norm: bool = True,
+    swiglu: bool = True,
+    geglu: bool = False,
     model: PreTrainedModel = None,
 ) -> None:
     """
@@ -1663,7 +1665,9 @@ def apply_liger_kernel_to_qwen3_vl(
            `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
            If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
        rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
-        swiglu (bool): Whether to apply Liger's SwiGLU MLP. Default is False.
+        layer_norm (bool): Whether to apply Liger's LayerNorm. Default is True.
+        swiglu (bool): Whether to apply Liger's SwiGLU MLP for text model. Default is True.
+        geglu (bool): Whether to apply Liger's GeGLU MLP for vision model. Default is False.
        model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been loaded. Default is None.
    """
@@ -1676,6 +1680,7 @@ def apply_liger_kernel_to_qwen3_vl(
     from transformers.models.qwen3_vl.modeling_qwen3_vl import Qwen3VLForConditionalGeneration
     from transformers.models.qwen3_vl.modeling_qwen3_vl import Qwen3VLModel
     from transformers.models.qwen3_vl.modeling_qwen3_vl import Qwen3VLTextModel
+    from transformers.models.qwen3_vl.modeling_qwen3_vl import Qwen3VLVisionModel

     from liger_kernel.transformers.model.qwen3_vl import lce_forward as qwen3_vl_lce_forward
@@ -1683,9 +1688,12 @@
     modeling_qwen3_vl.apply_rotary_pos_emb = liger_rotary_pos_emb_with_cast
     modeling_qwen3_vl.apply_rotary_pos_emb_vision = liger_rotary_pos_emb_with_cast_and_leading_batch

-    if rms_norm:
+    if rms_norm and model is None:
         modeling_qwen3_vl.Qwen3VLTextRMSNorm = LigerRMSNorm

+    if layer_norm and model is None:
+        modeling_qwen3_vl.nn.LayerNorm = LigerLayerNorm
+
     if cross_entropy:
         from transformers.loss.loss_utils import nn
@@ -1697,29 +1705,55 @@ def apply_liger_kernel_to_qwen3_vl(
     else:
         modeling_qwen3_vl.Qwen3VLForConditionalGeneration.forward = qwen3_vl_lce_forward

-    if model is not None and rms_norm:
+    if swiglu:
+        modeling_qwen3_vl.Qwen3VLTextMLP = LigerSwiGLUMLP
+
+    if geglu:
+        logger.warning(
+            "geglu is set to True; note that there might be numerical differences compared to the original model. "
+            "Check https://github.com/linkedin/Liger-Kernel/issues/959"
+        )
+        modeling_qwen3_vl.Qwen3VLVisionMLP = LigerGEGLUMLP
+
+    if model is not None:
         if isinstance(model, (Qwen3VLForConditionalGeneration, Qwen3VLModel)):
             text_model: Qwen3VLTextModel = model.language_model
+            vision_model: Qwen3VLVisionModel = model.visual
         elif isinstance(model, Qwen3VLTextModel):
-            text_model = model
+            text_model: Qwen3VLTextModel = model
+            vision_model = None
         else:
             raise TypeError(
                 f"Unsupported Qwen3VL model type. `model` must be `Qwen3VLForConditionalGeneration`, `Qwen3VLModel` or `Qwen3VLTextModel`. Got: {type(model)}"
             )

         _patch_qwen3_vl_rms_norm = partial(_patch_rms_norm_module, offset=0.0, casting_mode="llama")
+        _patch_qwen3_vl_qk_norm = partial(_patch_rms_norm_module, offset=0.0, casting_mode="llama", row_mode=True)

         if text_model is not None:
-            _patch_qwen3_vl_rms_norm(text_model.norm)
+            if rms_norm:
+                _patch_qwen3_vl_rms_norm(text_model.norm)
             for decoder_layer in text_model.layers:
-                _patch_qwen3_vl_rms_norm(decoder_layer.input_layernorm)
-                _patch_qwen3_vl_rms_norm(decoder_layer.post_attention_layernorm)
-                self_attn = getattr(decoder_layer, "self_attn", None)
-                if self_attn is not None:
-                    if hasattr(self_attn, "q_norm") and self_attn.q_norm is not None:
-                        _patch_qwen3_vl_rms_norm(self_attn.q_norm)
-                    if hasattr(self_attn, "k_norm") and self_attn.k_norm is not None:
-                        _patch_qwen3_vl_rms_norm(self_attn.k_norm)
+                if rms_norm:
+                    _patch_qwen3_vl_rms_norm(decoder_layer.input_layernorm)
+                    _patch_qwen3_vl_rms_norm(decoder_layer.post_attention_layernorm)
+                    self_attn = getattr(decoder_layer, "self_attn", None)
+                    if self_attn is not None:
+                        if hasattr(self_attn, "q_norm") and self_attn.q_norm is not None:
+                            _patch_qwen3_vl_qk_norm(self_attn.q_norm)
+                        if hasattr(self_attn, "k_norm") and self_attn.k_norm is not None:
+                            _patch_qwen3_vl_qk_norm(self_attn.k_norm)
+
+                if swiglu:
+                    _patch_swiglu_module(decoder_layer.mlp, LigerSwiGLUMLP)
+
+        if vision_model is not None:
+            for vision_block in vision_model.blocks:
+                if layer_norm:
+                    _patch_layer_norm_module(vision_block.norm1)
+                    _patch_layer_norm_module(vision_block.norm2)
+                if geglu:
+                    _patch_geglu_module(vision_block.mlp)


 def apply_liger_kernel_to_qwen3_vl_moe(
@@ -1727,7 +1761,9 @@
     cross_entropy: bool = False,
     fused_linear_cross_entropy: bool = True,
     rms_norm: bool = True,
-    swiglu: bool = False,
+    layer_norm: bool = True,
+    swiglu: bool = True,
+    geglu: bool = False,
     model: PreTrainedModel = None,
 ) -> None:
     """
@@ -1738,7 +1774,9 @@
        fused_linear_cross_entropy (bool): Whether to apply Liger's fused linear cross entropy loss. Default is False.
        rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
-        swiglu (bool): Whether to apply Liger's SwiGLU MLP. Default is False.
+        layer_norm (bool): Whether to apply Liger's LayerNorm. Default is True.
+        swiglu (bool): Whether to apply Liger's SwiGLU MLP for text model. Default is True.
+        geglu (bool): Whether to apply Liger's GeGLU MLP for vision model. Default is False.
        model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been loaded. Default is None.
""" @@ -1750,7 +1788,10 @@ def apply_liger_kernel_to_qwen3_vl_moe( from transformers.models.qwen3_vl_moe import modeling_qwen3_vl_moe from transformers.models.qwen3_vl_moe.modeling_qwen3_vl_moe import Qwen3VLMoeForConditionalGeneration from transformers.models.qwen3_vl_moe.modeling_qwen3_vl_moe import Qwen3VLMoeModel + from transformers.models.qwen3_vl_moe.modeling_qwen3_vl_moe import Qwen3VLMoeTextMLP from transformers.models.qwen3_vl_moe.modeling_qwen3_vl_moe import Qwen3VLMoeTextModel + from transformers.models.qwen3_vl_moe.modeling_qwen3_vl_moe import Qwen3VLMoeTextSparseMoeBlock + from transformers.models.qwen3_vl_moe.modeling_qwen3_vl_moe import Qwen3VLMoeVisionModel from liger_kernel.transformers.model.qwen3_vl_moe import lce_forward as qwen3_vl_moe_lce_forward @@ -1772,29 +1813,58 @@ def apply_liger_kernel_to_qwen3_vl_moe( else: modeling_qwen3_vl_moe.Qwen3VLMoeForConditionalGeneration.forward = qwen3_vl_moe_lce_forward + if swiglu: + modeling_qwen3_vl_moe.Qwen3VLMoeTextMLP = LigerSwiGLUMLP + if geglu: + logger.warning( + "geglu is set to True, there might be numerical differences compared to the original model. " + "Check https://github.com/linkedin/Liger-Kernel/issues/959" + ) + modeling_qwen3_vl_moe.Qwen3VLMoeVisionMLP = LigerGEGLUMLP + if model is not None and rms_norm: if isinstance(model, (Qwen3VLMoeForConditionalGeneration, Qwen3VLMoeModel)): text_model: Qwen3VLMoeTextModel = model.language_model + vision_model: Qwen3VLMoeVisionModel = model.visual elif isinstance(model, Qwen3VLMoeTextModel): text_model = model + vision_model = None else: raise TypeError( f"Unsupported Qwen3VLMoe model type. `model` must be `Qwen3VLMoeForConditionalGeneration`, `Qwen3VLMoeModel` or `Qwen3VLMoeTextModel`. Got: {type(model)}" ) _patch_qwen3_vl_moe_rms_norm = partial(_patch_rms_norm_module, offset=0.0, casting_mode="llama") + _patch_qwen3_vl_moe_qk_norm = partial(_patch_rms_norm_module, offset=0.0, casting_mode="llama", row_mode=True) if text_model is not None: - _patch_qwen3_vl_moe_rms_norm(text_model.norm) + if rms_norm: + _patch_qwen3_vl_moe_rms_norm(text_model.norm) for decoder_layer in text_model.layers: _patch_qwen3_vl_moe_rms_norm(decoder_layer.input_layernorm) _patch_qwen3_vl_moe_rms_norm(decoder_layer.post_attention_layernorm) self_attn = getattr(decoder_layer, "self_attn", None) if self_attn is not None: if hasattr(self_attn, "q_norm") and self_attn.q_norm is not None: - _patch_qwen3_vl_moe_rms_norm(self_attn.q_norm) + _patch_qwen3_vl_moe_qk_norm(self_attn.q_norm) if hasattr(self_attn, "k_norm") and self_attn.k_norm is not None: - _patch_qwen3_vl_moe_rms_norm(self_attn.k_norm) + _patch_qwen3_vl_moe_qk_norm(self_attn.k_norm) + if swiglu: + if isinstance(decoder_layer.mlp, Qwen3VLMoeTextSparseMoeBlock): + # TODO(xxx): Implement LigerMoe for MoE sparse block for trasnformers v5 + logger.warning( + "Skipping MLP patching for Qwen3VLMoeTextSparseMoeBlock. 
There will be a breaking change in transformers v5" + ) + elif isinstance(decoder_layer.mlp, Qwen3VLMoeTextMLP): + _patch_swiglu_module(decoder_layer.mlp, LigerSwiGLUMLP) + + if vision_model is not None: + for vision_block in vision_model.blocks: + if layer_norm: + _patch_layer_norm_module(vision_block.norm1) + _patch_layer_norm_module(vision_block.norm2) + if geglu: + _patch_geglu_module(vision_block.mlp) def apply_liger_kernel_to_phi3( diff --git a/test/convergence/bf16/test_mini_models.py b/test/convergence/bf16/test_mini_models.py index af528ed1c..7a2c35f3d 100644 --- a/test/convergence/bf16/test_mini_models.py +++ b/test/convergence/bf16/test_mini_models.py @@ -1459,7 +1459,7 @@ def run_mini_model( if "glm4" in model_name or "qwen3_next" in model_name: kwargs["rope"] = False - model_supports_layer_norm = "qwen2_vl" in model_name + model_supports_layer_norm = "qwen2_vl" in model_name or "qwen3_vl" in model_name if model_supports_layer_norm: kwargs["layer_norm"] = True diff --git a/test/convergence/bf16/test_mini_models_multimodal.py b/test/convergence/bf16/test_mini_models_multimodal.py index bd090e060..470658b2a 100644 --- a/test/convergence/bf16/test_mini_models_multimodal.py +++ b/test/convergence/bf16/test_mini_models_multimodal.py @@ -1150,7 +1150,7 @@ def run_mini_model_multimodal( "cross_entropy": False, } - if "qwen2_5_vl" not in model_name and "llava" not in model_name and "qwen3_vl" not in model_name: + if "qwen2_5_vl" not in model_name and "llava" not in model_name: kwargs["layer_norm"] = True if "gemma" in model_name: diff --git a/test/convergence/bf16/test_mini_models_with_logits.py b/test/convergence/bf16/test_mini_models_with_logits.py index f1d29b381..02e627556 100644 --- a/test/convergence/bf16/test_mini_models_with_logits.py +++ b/test/convergence/bf16/test_mini_models_with_logits.py @@ -1454,7 +1454,7 @@ def run_mini_model( if "glm4" in model_name or "llama4" in model_name or "qwen3_next" in model_name: kwargs["rope"] = False - model_supports_layer_norm = "qwen2_vl" in model_name + model_supports_layer_norm = "qwen2_vl" in model_name or "qwen3_vl" in model_name if model_supports_layer_norm: kwargs["layer_norm"] = True diff --git a/test/convergence/fp32/test_mini_models.py b/test/convergence/fp32/test_mini_models.py index 263a07323..9cc05a578 100644 --- a/test/convergence/fp32/test_mini_models.py +++ b/test/convergence/fp32/test_mini_models.py @@ -1446,7 +1446,7 @@ def run_mini_model( if "glm4" in model_name or "qwen3_next" in model_name: kwargs["rope"] = False - model_supports_layer_norm = "qwen2_vl" in model_name + model_supports_layer_norm = "qwen2_vl" in model_name or "qwen3_vl" in model_name if model_supports_layer_norm: kwargs["layer_norm"] = True diff --git a/test/convergence/fp32/test_mini_models_multimodal.py b/test/convergence/fp32/test_mini_models_multimodal.py index b7656418b..1a1f86cca 100644 --- a/test/convergence/fp32/test_mini_models_multimodal.py +++ b/test/convergence/fp32/test_mini_models_multimodal.py @@ -1288,7 +1288,7 @@ def run_mini_model_multimodal( } if "llama4" in model_name: kwargs["rope"] = False - if "qwen2_5_vl" not in model_name and "llava" not in model_name and "qwen3_vl" not in model_name: + if "qwen2_5_vl" not in model_name and "llava" not in model_name: kwargs["layer_norm"] = True if "gemma" in model_name: diff --git a/test/convergence/fp32/test_mini_models_with_logits.py b/test/convergence/fp32/test_mini_models_with_logits.py index f14d13218..87082576c 100644 --- a/test/convergence/fp32/test_mini_models_with_logits.py +++ 
b/test/convergence/fp32/test_mini_models_with_logits.py @@ -1466,7 +1466,7 @@ def run_mini_model( if "glm4" in model_name or "llama4" in model_name: kwargs["rope"] = False - model_supports_layer_norm = "qwen2_vl" in model_name + model_supports_layer_norm = "qwen2_vl" in model_name or "qwen3_vl" in model_name if model_supports_layer_norm: kwargs["layer_norm"] = True diff --git a/test/transformers/test_monkey_patch.py b/test/transformers/test_monkey_patch.py index 51fa02660..6dd4acd83 100755 --- a/test/transformers/test_monkey_patch.py +++ b/test/transformers/test_monkey_patch.py @@ -527,6 +527,15 @@ def test_apply_liger_kernel_to_instance_for_qwen3_vl_for_conditional_generation( if hasattr(self_attn, "k_norm") and self_attn.k_norm is not None: assert inspect.getsource(self_attn.k_norm.forward) != inspect.getsource(LigerRMSNorm.forward) + assert inspect.getsource(decoder_layer.mlp.forward) != inspect.getsource(LigerSwiGLUMLP.forward) + + for vision_block in dummy_model_instance.visual.blocks: + assert inspect.getsource(vision_block.norm1.forward) != inspect.getsource(LigerLayerNorm.forward) + assert inspect.getsource(vision_block.norm2.forward) != inspect.getsource(LigerLayerNorm.forward) + + # numerical issue with LigerGEGLUMLP, no patching check for now + # assert inspect.getsource(vision_block.mlp.forward) != inspect.getsource(LigerGEGLUMLP.forward) + # Test applying kernels to the model instance _apply_liger_kernel_to_instance(model=dummy_model_instance) @@ -547,6 +556,15 @@ def test_apply_liger_kernel_to_instance_for_qwen3_vl_for_conditional_generation( if hasattr(self_attn, "k_norm") and self_attn.k_norm is not None: assert inspect.getsource(self_attn.k_norm.forward) == inspect.getsource(LigerRMSNorm.forward) + assert inspect.getsource(decoder_layer.mlp.forward) == inspect.getsource(LigerSwiGLUMLP.forward) + + for vision_block in dummy_model_instance.visual.blocks: + assert inspect.getsource(vision_block.norm1.forward) == inspect.getsource(LigerLayerNorm.forward) + assert inspect.getsource(vision_block.norm2.forward) == inspect.getsource(LigerLayerNorm.forward) + + # numerical issue with LigerGEGLUMLP, no patching check for now + # assert inspect.getsource(vision_block.mlp.forward) == inspect.getsource(LigerGEGLUMLP.forward) + try: print(dummy_model_instance) except Exception as e: @@ -628,6 +646,15 @@ def test_apply_liger_kernel_to_instance_for_qwen3_vl(): if hasattr(self_attn, "k_norm") and self_attn.k_norm is not None: assert inspect.getsource(self_attn.k_norm.forward) != inspect.getsource(LigerRMSNorm.forward) + assert inspect.getsource(decoder_layer.mlp.forward) != inspect.getsource(LigerSwiGLUMLP.forward) + + for vision_block in dummy_model_instance.visual.blocks: + assert inspect.getsource(vision_block.norm1.forward) != inspect.getsource(LigerLayerNorm.forward) + assert inspect.getsource(vision_block.norm2.forward) != inspect.getsource(LigerLayerNorm.forward) + + # numerical issue with LigerGEGLUMLP, no patching check for now + # assert inspect.getsource(vision_block.mlp.forward) != inspect.getsource(LigerGEGLUMLP.forward) + # Test applying kernels to the model instance _apply_liger_kernel_to_instance(model=dummy_model_instance) @@ -648,6 +675,15 @@ def test_apply_liger_kernel_to_instance_for_qwen3_vl(): if hasattr(self_attn, "k_norm") and self_attn.k_norm is not None: assert inspect.getsource(self_attn.k_norm.forward) == inspect.getsource(LigerRMSNorm.forward) + assert inspect.getsource(decoder_layer.mlp.forward) == inspect.getsource(LigerSwiGLUMLP.forward) + + for vision_block 
in dummy_model_instance.visual.blocks: + assert inspect.getsource(vision_block.norm1.forward) == inspect.getsource(LigerLayerNorm.forward) + assert inspect.getsource(vision_block.norm2.forward) == inspect.getsource(LigerLayerNorm.forward) + + # numerical issue with LigerGEGLUMLP, no patching check for now + # assert inspect.getsource(vision_block.mlp.forward) == inspect.getsource(LigerGEGLUMLP.forward) + try: print(dummy_model_instance) except Exception as e: @@ -731,6 +767,8 @@ def test_apply_liger_kernel_to_instance_for_qwen3_vl_moe_for_conditional_generat # Ensure any monkey patching is cleaned up for subsequent tests with patch("transformers.models.qwen3_vl_moe.modeling_qwen3_vl_moe"): from transformers.models.qwen3_vl_moe.modeling_qwen3_vl_moe import Qwen3VLMoeForConditionalGeneration + from transformers.models.qwen3_vl_moe.modeling_qwen3_vl_moe import Qwen3VLMoeTextMLP + from transformers.models.qwen3_vl_moe.modeling_qwen3_vl_moe import Qwen3VLMoeTextSparseMoeBlock from liger_kernel.transformers.model.qwen3_vl_moe import lce_forward as qwen3_vl_moe_lce_forward @@ -782,7 +820,7 @@ def test_apply_liger_kernel_to_instance_for_qwen3_vl_moe_for_conditional_generat moe_intermediate_size=1024, num_experts_per_tok=2, num_experts=4, - mlp_only_layers=[], + mlp_only_layers=[0, 2], ).to_dict(), ) dummy_model_instance = Qwen3VLMoeForConditionalGeneration._from_config(config) @@ -806,6 +844,19 @@ def test_apply_liger_kernel_to_instance_for_qwen3_vl_moe_for_conditional_generat if hasattr(self_attn, "k_norm") and self_attn.k_norm is not None: assert inspect.getsource(self_attn.k_norm.forward) != inspect.getsource(LigerRMSNorm.forward) + if isinstance(decoder_layer.mlp, Qwen3VLMoeTextSparseMoeBlock): + # TODO(xxx): Implement LigerMoe for MoE sparse block for transformers v5 + pass + elif isinstance(decoder_layer.mlp, Qwen3VLMoeTextMLP): + assert inspect.getsource(decoder_layer.mlp.forward) != inspect.getsource(LigerSwiGLUMLP.forward) + + for vision_block in dummy_model_instance.visual.blocks: + assert inspect.getsource(vision_block.norm1.forward) != inspect.getsource(LigerLayerNorm.forward) + assert inspect.getsource(vision_block.norm2.forward) != inspect.getsource(LigerLayerNorm.forward) + + # numerical issue with LigerGEGLUMLP, no patching check for now + # assert inspect.getsource(vision_block.mlp.forward) != inspect.getsource(LigerGEGLUMLP.forward) + # Test applying kernels to the model instance _apply_liger_kernel_to_instance(model=dummy_model_instance) @@ -826,6 +877,19 @@ def test_apply_liger_kernel_to_instance_for_qwen3_vl_moe_for_conditional_generat if hasattr(self_attn, "k_norm") and self_attn.k_norm is not None: assert inspect.getsource(self_attn.k_norm.forward) == inspect.getsource(LigerRMSNorm.forward) + if isinstance(decoder_layer.mlp, Qwen3VLMoeTextSparseMoeBlock): + # TODO(xxx): Implement LigerMoe for MoE sparse block for transformers v5 + pass + elif isinstance(decoder_layer.mlp, Qwen3VLMoeTextMLP): + assert inspect.getsource(decoder_layer.mlp.forward) == inspect.getsource(LigerSwiGLUMLP.forward) + + for vision_block in dummy_model_instance.visual.blocks: + assert inspect.getsource(vision_block.norm1.forward) == inspect.getsource(LigerLayerNorm.forward) + assert inspect.getsource(vision_block.norm2.forward) == inspect.getsource(LigerLayerNorm.forward) + + # numerical issue with LigerGEGLUMLP, no patching check for now + # assert inspect.getsource(vision_block.mlp.forward) == inspect.getsource(LigerGEGLUMLP.forward) + try: print(dummy_model_instance) except Exception as e: @@ -837,6 
+901,8 @@ def test_apply_liger_kernel_to_instance_for_qwen3_vl_moe(): # Ensure any monkey patching is cleaned up for subsequent tests with patch("transformers.models.qwen3_vl_moe.modeling_qwen3_vl_moe"): from transformers.models.qwen3_vl_moe.modeling_qwen3_vl_moe import Qwen3VLMoeModel + from transformers.models.qwen3_vl_moe.modeling_qwen3_vl_moe import Qwen3VLMoeTextMLP + from transformers.models.qwen3_vl_moe.modeling_qwen3_vl_moe import Qwen3VLMoeTextSparseMoeBlock from liger_kernel.transformers.model.qwen3_vl_moe import lce_forward as qwen3_vl_moe_lce_forward @@ -888,7 +954,7 @@ def test_apply_liger_kernel_to_instance_for_qwen3_vl_moe(): moe_intermediate_size=1024, num_experts_per_tok=2, num_experts=4, - mlp_only_layers=[], + mlp_only_layers=[0, 2], ).to_dict(), ) dummy_model_instance = Qwen3VLMoeModel._from_config(config) @@ -912,6 +978,19 @@ def test_apply_liger_kernel_to_instance_for_qwen3_vl_moe(): if hasattr(self_attn, "k_norm") and self_attn.k_norm is not None: assert inspect.getsource(self_attn.k_norm.forward) != inspect.getsource(LigerRMSNorm.forward) + if isinstance(decoder_layer.mlp, Qwen3VLMoeTextSparseMoeBlock): + # TODO(xxx): Implement LigerMoe for MoE sparse block for transformers v5 + pass + elif isinstance(decoder_layer.mlp, Qwen3VLMoeTextMLP): + assert inspect.getsource(decoder_layer.mlp.forward) != inspect.getsource(LigerSwiGLUMLP.forward) + + for vision_block in dummy_model_instance.visual.blocks: + assert inspect.getsource(vision_block.norm1.forward) != inspect.getsource(LigerLayerNorm.forward) + assert inspect.getsource(vision_block.norm2.forward) != inspect.getsource(LigerLayerNorm.forward) + + # numerical issue with LigerGEGLUMLP, no patching check for now + # assert inspect.getsource(vision_block.mlp.forward) != inspect.getsource(LigerGEGLUMLP.forward) + # Test applying kernels to the model instance _apply_liger_kernel_to_instance(model=dummy_model_instance) @@ -932,6 +1011,19 @@ def test_apply_liger_kernel_to_instance_for_qwen3_vl_moe(): if hasattr(self_attn, "k_norm") and self_attn.k_norm is not None: assert inspect.getsource(self_attn.k_norm.forward) == inspect.getsource(LigerRMSNorm.forward) + if isinstance(decoder_layer.mlp, Qwen3VLMoeTextSparseMoeBlock): + # TODO(xxx): Implement LigerMoe for MoE sparse block for transformers v5 + pass + elif isinstance(decoder_layer.mlp, Qwen3VLMoeTextMLP): + assert inspect.getsource(decoder_layer.mlp.forward) == inspect.getsource(LigerSwiGLUMLP.forward) + + for vision_block in dummy_model_instance.visual.blocks: + assert inspect.getsource(vision_block.norm1.forward) == inspect.getsource(LigerLayerNorm.forward) + assert inspect.getsource(vision_block.norm2.forward) == inspect.getsource(LigerLayerNorm.forward) + + # numerical issue with LigerGEGLUMLP, no patching check for now + # assert inspect.getsource(vision_block.mlp.forward) == inspect.getsource(LigerGEGLUMLP.forward) + try: print(dummy_model_instance) except Exception as e:
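
Reviewer note: the sketch below (not part of the patch) shows how the new flags are intended to be exercised against an already-loaded model. The checkpoint id is a hypothetical placeholder, and it assumes apply_liger_kernel_to_qwen3_vl is exported from liger_kernel.transformers alongside the other apply_liger_kernel_to_* entry points.

# Usage sketch for the new layer_norm / swiglu / geglu flags.
# The checkpoint id below is a hypothetical placeholder.
import torch
from transformers import AutoModelForImageTextToText

from liger_kernel.transformers import apply_liger_kernel_to_qwen3_vl

model = AutoModelForImageTextToText.from_pretrained(
    "Qwen/Qwen3-VL-Instruct",  # hypothetical checkpoint id
    torch_dtype=torch.bfloat16,
)

# rms_norm/swiglu cover the text stack (including the row-mode q/k norms),
# layer_norm covers the vision blocks; geglu defaults to False because of
# the numerical differences tracked in
# https://github.com/linkedin/Liger-Kernel/issues/959
apply_liger_kernel_to_qwen3_vl(
    rms_norm=True,
    layer_norm=True,
    swiglu=True,
    geglu=False,
    model=model,
)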
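
A second note on the test pattern: the assertions above detect in-place patching by comparing inspect.getsource of a module's bound forward against the Liger implementation. A toy, self-contained illustration of why that works follows; TinyNorm and liger_like_forward are hypothetical stand-ins, and the binding mirrors the kind of instance-level patching the _patch_*_module helpers perform (run from a file, since inspect.getsource needs source access).

# Toy illustration of the inspect.getsource patching check used in the tests.
import inspect
import types

import torch.nn as nn


class TinyNorm(nn.Module):
    def forward(self, x):
        return x


def liger_like_forward(self, x):
    # Stand-in for a Liger kernel forward, e.g. LigerLayerNorm.forward.
    return x * 1.0


m = TinyNorm()
# Before patching, the bound forward still carries TinyNorm's source.
assert inspect.getsource(m.forward) != inspect.getsource(liger_like_forward)

# Bind the replacement onto the instance; the bound method's source now
# matches the replacement implementation, which is what the tests assert.
m.forward = types.MethodType(liger_like_forward, m)
assert inspect.getsource(m.forward) == inspect.getsource(liger_like_forward)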