@@ -677,157 +677,164 @@ def _get_seqlen(cls) -> torch.Tensor:
@requires_cuda()
def test_plug_multi_head_attention_qwen25_packed_float16(self):
    """Compare eager and ONNX outputs with ``QWEN25ATTENTION=PACKED``.

    Random float16 tensors are moved to CUDA and passed to
    ``qwen_sdpa_attention_versatile.verify`` twice, once per scaling
    value. Each run must report as many ONNX outputs and diffs as eager
    outputs, and the first output must agree within 0.01.
    """
    from onnx_diagnostic.torch_export_patches.patches._patch_transformers_qwen2_5 import (
        qwen_sdpa_attention_versatile,
    )

    tolerance = 0.01
    with self.set_env("QWEN25ATTENTION", "PACKED"):
        # Three same-shaped random tensors, then the sequence lengths.
        qkv = [
            torch.rand((1, 16, 1292, 80), dtype=torch.float16).to("cuda")
            for _ in range(3)
        ]
        seqlen = self._get_seqlen().to("cuda")

        for scale in (0.5, 0.11180339887498948):
            checked = qwen_sdpa_attention_versatile.verify(
                *qkv, seqlen, scaling=scale, num_heads=16
            )
            self.assertEqual(len(checked.eager_outputs), len(checked.onnx_outputs))
            self.assertEqual(len(checked.eager_outputs), len(checked.diffs))
            self.assertEqualArray(
                checked.eager_outputs[0], checked.onnx_outputs[0], atol=tolerance
            )
            self.assertLess(checked.diffs[0]["abs"], tolerance)
705706
@requires_onnxruntime("1.25")
@unittest.skipIf(not patch_qwen2_5, "Qwen25 not part of this transformers")
def test_plug_multi_head_attention_qwen25_loopmha_float16(self):
    """Compare eager and ONNX outputs with ``QWEN25ATTENTION=LOOPMHA``, float16.

    ``qwen_sdpa_attention_versatile.verify`` runs twice on the same random
    inputs. The first run also passes a ``dump_onnx_model`` path; the
    second uses a different scaling and no dump. The first output must
    agree within 0.01 in both runs.
    """
    from onnx_diagnostic.torch_export_patches.patches._patch_transformers_qwen2_5 import (
        qwen_sdpa_attention_versatile,
    )

    tolerance = 0.01
    with self.set_env("QWEN25ATTENTION", "LOOPMHA"):
        qkv = [torch.rand((1, 16, 1292, 80), dtype=torch.float16) for _ in range(3)]
        seqlen = self._get_seqlen()

        # Keyword arguments of the two verification runs, in execution order.
        runs = [
            dict(
                scaling=0.5,
                num_heads=16,
                dump_onnx_model=self.get_dump_file(
                    "test_plug_packed_multi_head_attention_qwen25_loopmha_float16.onnx"
                ),
            ),
            dict(scaling=0.11180339887498948, num_heads=16),
        ]
        for options in runs:
            checked = qwen_sdpa_attention_versatile.verify(*qkv, seqlen, **options)
            self.assertEqual(len(checked.eager_outputs), len(checked.onnx_outputs))
            self.assertEqual(len(checked.eager_outputs), len(checked.diffs))
            self.assertEqualArray(
                checked.eager_outputs[0], checked.onnx_outputs[0], atol=tolerance
            )
            self.assertLess(checked.diffs[0]["abs"], tolerance)
740742
@requires_onnxruntime("1.25")
@unittest.skipIf(not patch_qwen2_5, "Qwen25 not part of this transformers")
def test_plug_multi_head_attention_qwen25_loopmha_float32(self):
    """Compare eager and ONNX outputs with ``QWEN25ATTENTION=LOOPMHA``, float32.

    ``qwen_sdpa_attention_versatile.verify`` runs twice on the same random
    inputs. The first run also passes a ``dump_onnx_model`` path; the
    second uses a different scaling and no dump. The first output must
    agree within 1e-5 in both runs.
    """
    from onnx_diagnostic.torch_export_patches.patches._patch_transformers_qwen2_5 import (
        qwen_sdpa_attention_versatile,
    )

    tolerance = 1e-5
    with self.set_env("QWEN25ATTENTION", "LOOPMHA"):
        qkv = [torch.rand((1, 16, 1292, 80), dtype=torch.float32) for _ in range(3)]
        seqlen = self._get_seqlen()

        # The dump is named after this float32 test so that it does not
        # reuse the file name passed by the float16 variant.
        runs = [
            dict(
                scaling=0.5,
                num_heads=16,
                dump_onnx_model=self.get_dump_file(
                    "test_plug_packed_multi_head_attention_qwen25_loopmha_float32.onnx"
                ),
            ),
            dict(scaling=0.11180339887498948, num_heads=16),
        ]
        for options in runs:
            checked = qwen_sdpa_attention_versatile.verify(*qkv, seqlen, **options)
            self.assertEqual(len(checked.eager_outputs), len(checked.onnx_outputs))
            self.assertEqual(len(checked.eager_outputs), len(checked.diffs))
            self.assertEqualArray(
                checked.eager_outputs[0], checked.onnx_outputs[0], atol=tolerance
            )
            self.assertLess(checked.diffs[0]["abs"], tolerance)
775778
@requires_onnxruntime("1.25")
@unittest.skipIf(not patch_qwen2_5, "Qwen25 not part of this transformers")
def test_plug_multi_head_attention_qwen25_loopa24_float16(self):
    """Compare eager and ONNX outputs with ``QWEN25ATTENTION=LOOPA24``, float16.

    ``qwen_sdpa_attention_versatile.verify`` runs twice on the same random
    inputs, each time with its own scaling and absolute tolerance
    (1e-2 for the first run, 0.005 for the second).
    """
    from onnx_diagnostic.torch_export_patches.patches._patch_transformers_qwen2_5 import (
        qwen_sdpa_attention_versatile,
    )

    # NOTE(review): the strategy name is spelled "LOOPA24" to match the
    # test name; presumably this is the value the patch module expects --
    # confirm against its list of supported strategies.
    with self.set_env("QWEN25ATTENTION", "LOOPA24"):
        qkv = [torch.rand((1, 16, 1292, 80), dtype=torch.float16) for _ in range(3)]
        seqlen = self._get_seqlen()

        # (scaling, absolute tolerance) for each verification run.
        for scale, tolerance in ((0.5, 1e-2), (0.11180339887498948, 0.005)):
            checked = qwen_sdpa_attention_versatile.verify(*qkv, seqlen, scaling=scale)
            self.assertEqual(len(checked.eager_outputs), len(checked.onnx_outputs))
            self.assertEqual(len(checked.eager_outputs), len(checked.diffs))
            self.assertEqualArray(
                checked.eager_outputs[0], checked.onnx_outputs[0], atol=tolerance
            )
            self.assertLess(checked.diffs[0]["abs"], tolerance)
803809
@requires_onnxruntime("1.25")
@unittest.skipIf(not patch_qwen2_5, "Qwen25 not part of this transformers")
def test_plug_multi_head_attention_qwen25_loopa24_float32(self):
    """Compare eager and ONNX outputs with ``QWEN25ATTENTION=LOOPA24``, float32.

    ``qwen_sdpa_attention_versatile.verify`` runs twice on the same random
    inputs, once per scaling value. The first output must agree within
    1e-5 in both runs.
    """
    from onnx_diagnostic.torch_export_patches.patches._patch_transformers_qwen2_5 import (
        qwen_sdpa_attention_versatile,
    )

    tolerance = 1e-5
    # NOTE(review): the strategy name is spelled "LOOPA24" to match the
    # test name; presumably this is the value the patch module expects --
    # confirm against its list of supported strategies.
    with self.set_env("QWEN25ATTENTION", "LOOPA24"):
        qkv = [torch.rand((1, 16, 1292, 80), dtype=torch.float32) for _ in range(3)]
        seqlen = self._get_seqlen()

        for scale in (0.5, 0.11180339887498948):
            checked = qwen_sdpa_attention_versatile.verify(*qkv, seqlen, scaling=scale)
            self.assertEqual(len(checked.eager_outputs), len(checked.onnx_outputs))
            self.assertEqual(len(checked.eager_outputs), len(checked.diffs))
            self.assertEqualArray(
                checked.eager_outputs[0], checked.onnx_outputs[0], atol=tolerance
            )
            self.assertLess(checked.diffs[0]["abs"], tolerance)
831838
832839 @unittest .skipIf (not patch_funnel , "Funnel not part of this transformers" )
833840 def test_model_funnel (self ):
0 commit comments