Skip to content

Commit 03e14ef

Browse files
authored
Arm backend: Add bf16 support for aten.index_select and aten.unfold_copy (#19751)
Follow-up to #17097, which added BF16 support to the TOSA GATHER op. `aten.index_select` and `aten.unfold_copy` both lower via TOSA GATHER, but their support checks were not updated at the time. In both decompositions (`DecomposeIndexSelectToGatherPass()` and `DecomposeUnfoldToGatherPass()`), the bf16 values tensor flows through dtype-agnostic reshape ops and `tosa.GATHER`, which accepts `BF16`. The support check was the only blocker.

| Op                  | bf16 before | bf16 after |
|---------------------|:-----------:|:----------:|
| `aten.gather`       | ✅ | ✅ |
| `aten.index.Tensor` | ✅ | ✅ |
| `aten.slice_copy`   | ✅ | ✅ |
| `aten.index_select` | ❌ | ✅ |
| `aten.unfold_copy`  | ❌ | ✅ |

Changes:
- `index_select_support.py`, `unfold_copy_support.py`: extend float branch to include `bfloat16`; add bf16 extension guard; update rejection message.
- `test_index_select.py`, `test_unfold_copy.py`: add isolated `_tosa_FP_bf16` test functions using `TosaPipelineFP(..., tosa_extensions=["bf16"])`.

### Test plan

`test_index_select_tosa_FP_bf16` and `test_unfold_copy_tosa_FP_bf16` exercise the bf16 path end-to-end through `TosaPipelineFP` with the bf16 extension enabled, following the same pattern as the existing `test_slice_tensor_tosa_FP_bf16` from #17492.
1 parent b73df0b commit 03e14ef

4 files changed

Lines changed: 78 additions & 6 deletions

File tree

backends/arm/operator_support/index_select_support.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -77,8 +77,16 @@ def is_node_tosa_supported(
7777
f"{node.target}: dtype {values_dtype} requires INT profile.",
7878
)
7979
return False
80-
# fp16/fp32: either FP profile, or INT profile (via quantization)
81-
elif values_dtype in (torch.float16, torch.float32):
80+
# fp16/fp32/bf16: either FP profile, or INT profile (via quantization)
81+
elif values_dtype in (torch.float16, torch.float32, torch.bfloat16):
82+
if values_dtype == torch.bfloat16 and not tosa_spec.support_extension(
83+
"bf16"
84+
):
85+
self.reporter.report_reject(
86+
node,
87+
f"{node.target}: dtype {values_dtype} requires bf16 extension.",
88+
)
89+
return False
8290
if not (tosa_spec.support_float() or tosa_spec.support_integer()):
8391
self.reporter.report_reject(
8492
node,
@@ -90,7 +98,7 @@ def is_node_tosa_supported(
9098
self.reporter.report_reject(
9199
node,
92100
f"{node.target}: unsupported values dtype {values_dtype}; "
93-
"expected bool/int8/int16/int32/float16/float32.",
101+
"expected bool/int8/int16/int32/float16/bfloat16/float32.",
94102
)
95103
return False
96104

backends/arm/operator_support/unfold_copy_support.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -84,8 +84,16 @@ def is_node_tosa_supported(
8484
f"{node.target}: dtype {values_dtype} requires INT profile.",
8585
)
8686
return False
87-
# fp16/fp32: either FP profile, or INT profile (via quantization)
88-
elif values_dtype in (torch.float16, torch.float32):
87+
# fp16/fp32/bf16: either FP profile, or INT profile (via quantization)
88+
elif values_dtype in (torch.float16, torch.float32, torch.bfloat16):
89+
if values_dtype == torch.bfloat16 and not tosa_spec.support_extension(
90+
"bf16"
91+
):
92+
self.reporter.report_reject(
93+
node,
94+
f"{node.target}: dtype {values_dtype} requires bf16 extension.",
95+
)
96+
return False
8997
if not (tosa_spec.support_float() or tosa_spec.support_integer()):
9098
self.reporter.report_reject(
9199
node,
@@ -97,7 +105,7 @@ def is_node_tosa_supported(
97105
self.reporter.report_reject(
98106
node,
99107
f"{node.target}: unsupported values dtype {values_dtype}; "
100-
"expected bool/int8/int16/int32/float16/float32.",
108+
"expected bool/int8/int16/int32/float16/bfloat16/float32.",
101109
)
102110
return False
103111

backends/arm/test/ops/test_index_select.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,26 @@ def forward(self, input_: torch.Tensor, dim: int, index_: torch.Tensor):
6161
torch.tensor([3, 1], dtype=torch.int32), # [W=2]
6262
),
6363
}
64+
test_data_fp_bf16: dict[str, input_params] = {
65+
# Rank-2: [K, C] -> index_select dim=0 => [W, C]
66+
"test_bf16_rank2_dim0": (
67+
torch.tensor(
68+
[[0.5, 1.25, 2.5], [3.5, 4.25, 5.75], [6.5, 7.25, 8.75]],
69+
dtype=torch.bfloat16,
70+
), # [K=3, C=3]
71+
0,
72+
torch.tensor([2, 0], dtype=torch.int32), # [W=2]
73+
),
74+
# Rank-3: [N, K, C] -> index_select dim=-1 => [N, K, W]
75+
"test_bf16_rank3_dim_neg1": (
76+
torch.tensor(
77+
[[[0.5, 1.5], [2.5, 3.5]], [[4.5, 5.5], [6.5, 7.5]]],
78+
dtype=torch.bfloat16,
79+
), # [N=2, K=2, C=2]
80+
-1,
81+
torch.tensor([1, 0], dtype=torch.int32), # [W=2]
82+
),
83+
}
6484

6585
# ---- INT profile: integer inputs + bool ----
6686
test_data_int: dict[str, input_params] = {
@@ -104,6 +124,18 @@ def test_index_select_tosa_FP(test_data: input_params):
104124
pipeline.run()
105125

106126

127+
@common.parametrize("test_data", test_data_fp_bf16)
128+
def test_index_select_tosa_FP_bf16(test_data: input_params):
129+
pipeline = TosaPipelineFP[input_params](
130+
IndexSelect(),
131+
test_data,
132+
aten_op=IndexSelect.aten_op,
133+
exir_op=IndexSelect.exir_op,
134+
tosa_extensions=["bf16"],
135+
)
136+
pipeline.run()
137+
138+
107139
@common.parametrize("test_data", test_data_int | test_data_fp)
108140
def test_index_select_tosa_INT(test_data: input_params):
109141
# INT profile runs quantized, so we test both int inputs and float inputs here.

backends/arm/test/ops/test_unfold_copy.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,18 @@ def forward(self, input_: torch.Tensor, dim_: int, size_: int, step_: int):
120120
),
121121
}
122122

123+
test_data_bf16: dict[str, input_params] = {
124+
"test_bf16_2d_dim1": (
125+
torch.tensor(
126+
[[0.1, 0.2, 0.3, 0.4, 0.5], [1.1, 1.2, 1.3, 1.4, 1.5]],
127+
dtype=torch.bfloat16,
128+
), # [B=2, T=5]
129+
1,
130+
3,
131+
2, # U=(5-3)//2+1=2 -> [B=2, U=2, C=3]
132+
),
133+
}
134+
123135

124136
@common.parametrize("test_data", test_data_fp)
125137
def test_unfold_copy_tosa_FP(test_data: input_params):
@@ -132,6 +144,18 @@ def test_unfold_copy_tosa_FP(test_data: input_params):
132144
pipeline.run()
133145

134146

147+
@common.parametrize("test_data", test_data_bf16)
148+
def test_unfold_copy_tosa_FP_bf16(test_data: input_params):
149+
pipeline = TosaPipelineFP[input_params](
150+
UnfoldCopy(),
151+
test_data,
152+
aten_op=UnfoldCopy.aten_op,
153+
exir_op=UnfoldCopy.exir_op,
154+
tosa_extensions=["bf16"],
155+
)
156+
pipeline.run()
157+
158+
135159
@common.parametrize("test_data", test_data_int | test_data_fp)
136160
def test_unfold_copy_tosa_INT(test_data: input_params):
137161
pipeline = TosaPipelineINT[input_params](

0 commit comments

Comments
 (0)