3333 Qwen3VLVisionRotaryEmbedding ,
3434)
3535
36- input_t = Tuple [torch .Tensor , ...]
36+ input_t = Tuple [torch .Tensor | int , ...]
3737
3838
3939def _make_qwen3_vl_2b_instruct_layer_config ():
@@ -99,6 +99,19 @@ def prepare_model_and_inputs(cls):
9999 raise NotImplementedError
100100
101101
def _to_bfloat16(
    model: torch.nn.Module, inputs: input_t
) -> tuple[torch.nn.Module, input_t]:
    """Cast a module and its floating-point tensor inputs to ``torch.bfloat16``.

    Args:
        model: Module to convert; ``model.to(torch.bfloat16)`` is called on it.
        inputs: Positional example inputs. Entries that are floating-point
            tensors are cast; every other entry (non-floating tensors,
            non-tensor values) is passed through unchanged.

    Returns:
        A ``(model, inputs)`` pair: the result of ``model.to(...)`` and a new
        tuple holding the inputs in their original order.
    """

    def _cast(value):
        # Only floating-point tensors are cast; the rest is passed through unchanged.
        if isinstance(value, torch.Tensor) and value.is_floating_point():
            return value.to(torch.bfloat16)
        return value

    bf16_model = model.to(torch.bfloat16)
    return bf16_model, tuple(_cast(value) for value in inputs)
113+
114+
102115class Qwen3VLVisionMLPModel (Qwen3VLTestModule ):
103116 def __init__ (self , config ) -> None :
104117 super ().__init__ ()
@@ -442,6 +455,18 @@ class Qwen3VLTestCase:
442455
443456VGF_NO_QUANT_TEST_CASES : dict [str , Qwen3VLTestCase ] = TOSA_FP_TEST_CASES
444457
# Cases run by the bf16 TOSA FP test. Each name maps to the very same case
# object that is registered under that name in TOSA_FP_TEST_CASES.
TOSA_BF16_TEST_CASES: dict[str, Qwen3VLTestCase] = {
    case_name: TOSA_FP_TEST_CASES[case_name]
    for case_name in (
        "vision_mlp",
        "vision_patch_embed",
        "vision_rotary_embedding",
        "vision_rotary_apply",
        "vision_attention",
        "vision_block",
        "vision_patch_merger",
        "text_rms_norm",
        "qk_norm",
    )
}
469+
445470
446471@common .parametrize (
447472 "test_case" ,
@@ -460,6 +485,27 @@ def test_qwen3_vl_tosa_FP(test_case: Qwen3VLTestCase):
460485 pipeline .run ()
461486
462487
@common.parametrize("test_case", TOSA_BF16_TEST_CASES)
def test_qwen3_vl_tosa_FP_bf16(test_case: Qwen3VLTestCase):
    """Run one Qwen3-VL test case through ``TosaPipelineFP`` in bfloat16.

    The model and inputs built by the case's ``prepare_model_and_inputs`` are
    cast with ``_to_bfloat16`` and the pipeline is created with the ``"bf16"``
    TOSA extension and absolute/relative tolerances of ``1e-2``.
    """
    fp_model, fp_inputs = test_case.model_cls.prepare_model_and_inputs()
    bf16_model, bf16_inputs = _to_bfloat16(fp_model, fp_inputs)

    # The same tolerance value is used for both atol and rtol.
    tolerance = 1e-2

    # Pipeline construction and execution both happen with autograd disabled.
    with torch.no_grad():
        bf16_pipeline = TosaPipelineFP[input_t](
            bf16_model,
            bf16_inputs,
            aten_op=[],
            exir_op=[],
            transform_passes=list(test_case.transform_passes),
            tosa_extensions=["bf16"],
            atol=tolerance,
            rtol=tolerance,
        )
        bf16_pipeline.run()
507+
508+
463509@common .SkipIfNoModelConverter
464510@common .parametrize (
465511 "test_case" ,
0 commit comments