Skip to content

Commit ea37954

Browse files
authored
Arm backend: Add BF16 layer tests for Qwen (pytorch#19767)
* Add layers that run in BF16 in the HF model Change-Id: If75434db138059f3a433a70abda3f3e26f6dd3b6 cc @digantdesai @freddan80 @per @zingo @oscarandersson8218 @mansnils @Sebastian-Larsson @robell @rascani --------- Signed-off-by: Tom Allsop <tom.allsop@arm.com>
1 parent 501d641 commit ea37954

1 file changed

Lines changed: 47 additions & 1 deletion

File tree

backends/arm/test/models/Qwen3_VL/test_qwen3_vl_layers.py

Lines changed: 47 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@
3333
Qwen3VLVisionRotaryEmbedding,
3434
)
3535

36-
input_t = Tuple[torch.Tensor, ...]
36+
# Type of the example-input tuples used by the tests in this file: any number
# of positional items, each either a tensor or a plain int.
input_t = Tuple[torch.Tensor | int, ...]
3737

3838

3939
def _make_qwen3_vl_2b_instruct_layer_config():
@@ -99,6 +99,19 @@ def prepare_model_and_inputs(cls):
9999
raise NotImplementedError
100100

101101

102+
def _to_bfloat16(
103+
model: torch.nn.Module, inputs: input_t
104+
) -> tuple[torch.nn.Module, input_t]:
105+
return model.to(torch.bfloat16), tuple(
106+
(
107+
x.to(torch.bfloat16)
108+
if isinstance(x, torch.Tensor) and x.is_floating_point()
109+
else x
110+
)
111+
for x in inputs
112+
)
113+
114+
102115
class Qwen3VLVisionMLPModel(Qwen3VLTestModule):
103116
def __init__(self, config) -> None:
104117
super().__init__()
@@ -442,6 +455,18 @@ class Qwen3VLTestCase:
442455

443456
# The VGF no-quant cases are the TOSA FP cases: this binds the same dict
# object (not a copy), so any change to one is visible through the other.
VGF_NO_QUANT_TEST_CASES: dict[str, Qwen3VLTestCase] = TOSA_FP_TEST_CASES
444457

458+
# Subset of the FP layer cases that is also exercised in bfloat16. Each entry
# reuses the corresponding TOSA_FP_TEST_CASES object under the same key.
TOSA_BF16_TEST_CASES: dict[str, Qwen3VLTestCase] = {
    layer_name: TOSA_FP_TEST_CASES[layer_name]
    for layer_name in (
        "vision_mlp",
        "vision_patch_embed",
        "vision_rotary_embedding",
        "vision_rotary_apply",
        "vision_attention",
        "vision_block",
        "vision_patch_merger",
        "text_rms_norm",
        "qk_norm",
    )
}
469+
445470

446471
@common.parametrize(
447472
"test_case",
@@ -460,6 +485,27 @@ def test_qwen3_vl_tosa_FP(test_case: Qwen3VLTestCase):
460485
pipeline.run()
461486

462487

488+
@common.parametrize(
    "test_case",
    TOSA_BF16_TEST_CASES,
)
def test_qwen3_vl_tosa_FP_bf16(test_case: Qwen3VLTestCase):
    """Run one Qwen3-VL layer through the TOSA FP pipeline in bfloat16.

    The model and inputs produced by the test case are cast with
    ``_to_bfloat16`` and the pipeline is built with the ``bf16`` TOSA
    extension and absolute/relative tolerances of 1e-2.
    """
    base_model, base_inputs = test_case.model_cls.prepare_model_and_inputs()
    bf16_model, bf16_inputs = _to_bfloat16(base_model, base_inputs)
    passes = list(test_case.transform_passes)
    with torch.no_grad():
        bf16_pipeline = TosaPipelineFP[input_t](
            bf16_model,
            bf16_inputs,
            aten_op=[],
            exir_op=[],
            transform_passes=passes,
            tosa_extensions=["bf16"],
            atol=1e-2,
            rtol=1e-2,
        )
        bf16_pipeline.run()
507+
508+
463509
@common.SkipIfNoModelConverter
464510
@common.parametrize(
465511
"test_case",

0 commit comments

Comments
 (0)