Skip to content

Commit 86e861c

Browse files
authored
Update fused quant broadcast logic (pytorch#20171)

Differential Revision: D108065588
Pull Request resolved: pytorch#20171
1 parent 39dade2 commit 86e861c

16 files changed

Lines changed: 105 additions & 236 deletions

backends/cadence/fused_quant/op_add.cpp

Lines changed: 3 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -43,19 +43,16 @@ Tensor& add_out(
4343
ScalarType inp_dtype,
4444
int64_t inp_quant_min,
4545
int64_t inp_quant_max,
46-
optional<int64_t> inp_axis,
4746
const optional<Tensor>& other_scale,
4847
const optional<Tensor>& other_zero_point,
4948
ScalarType other_dtype,
5049
int64_t other_quant_min,
5150
int64_t other_quant_max,
52-
optional<int64_t> other_axis,
5351
const optional<Tensor>& out_scale,
5452
const optional<Tensor>& out_zero_point,
5553
ScalarType out_dtype,
5654
int64_t out_quant_min,
5755
int64_t out_quant_max,
58-
optional<int64_t> out_axis,
5956
double alpha,
6057
Tensor& out) {
6158
int64_t numel = inp.numel();
@@ -72,7 +69,7 @@ Tensor& add_out(
7269
}
7370
inp_buf.resize(numel);
7471
QParams qp = extract_qparams(
75-
inp_scale, inp_zero_point, inp_quant_min, inp_quant_max, inp_axis, inp);
72+
inp_scale, inp_zero_point, inp_quant_min, inp_quant_max, inp);
7673
FUSED_QUANT_DTYPE_SWITCH(
7774
inp.scalar_type(),
7875
scalar_t,
@@ -88,12 +85,7 @@ Tensor& add_out(
8885
}
8986
other_buf.resize(numel);
9087
QParams qp = extract_qparams(
91-
other_scale,
92-
other_zero_point,
93-
other_quant_min,
94-
other_quant_max,
95-
other_axis,
96-
other);
88+
other_scale, other_zero_point, other_quant_min, other_quant_max, other);
9789
FUSED_QUANT_DTYPE_SWITCH(
9890
other.scalar_type(),
9991
scalar_t,
@@ -107,7 +99,7 @@ Tensor& add_out(
10799
add_kernel(inp_float, other_float, result_float.data(), numel, alpha_f);
108100

109101
QParams qp = extract_qparams(
110-
out_scale, out_zero_point, out_quant_min, out_quant_max, out_axis, out);
102+
out_scale, out_zero_point, out_quant_min, out_quant_max, out);
111103
FUSED_QUANT_DTYPE_SWITCH(
112104
out.scalar_type(),
113105
scalar_t,

backends/cadence/fused_quant/op_add.h

Lines changed: 0 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -24,20 +24,17 @@ executorch::aten::Tensor& add_out(
2424
executorch::aten::ScalarType inp_dtype,
2525
int64_t inp_quant_min,
2626
int64_t inp_quant_max,
27-
executorch::aten::optional<int64_t> inp_axis,
2827
const executorch::aten::optional<executorch::aten::Tensor>& other_scale,
2928
const executorch::aten::optional<executorch::aten::Tensor>&
3029
other_zero_point,
3130
executorch::aten::ScalarType other_dtype,
3231
int64_t other_quant_min,
3332
int64_t other_quant_max,
34-
executorch::aten::optional<int64_t> other_axis,
3533
const executorch::aten::optional<executorch::aten::Tensor>& out_scale,
3634
const executorch::aten::optional<executorch::aten::Tensor>& out_zero_point,
3735
executorch::aten::ScalarType out_dtype,
3836
int64_t out_quant_min,
3937
int64_t out_quant_max,
40-
executorch::aten::optional<int64_t> out_axis,
4138
double alpha,
4239
executorch::aten::Tensor& out);
4340

backends/cadence/fused_quant/op_bmm.cpp

Lines changed: 3 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -53,19 +53,16 @@ Tensor& bmm_out(
5353
ScalarType inp_dtype,
5454
int64_t inp_quant_min,
5555
int64_t inp_quant_max,
56-
optional<int64_t> inp_axis,
5756
const optional<Tensor>& other_scale,
5857
const optional<Tensor>& other_zero_point,
5958
ScalarType other_dtype,
6059
int64_t other_quant_min,
6160
int64_t other_quant_max,
62-
optional<int64_t> other_axis,
6361
const optional<Tensor>& out_scale,
6462
const optional<Tensor>& out_zero_point,
6563
ScalarType out_dtype,
6664
int64_t out_quant_min,
6765
int64_t out_quant_max,
68-
optional<int64_t> out_axis,
6966
Tensor& out) {
7067
int64_t batch = inp.size(0);
7168
int64_t M = inp.size(1);
@@ -87,7 +84,7 @@ Tensor& bmm_out(
8784
}
8885
inp_buf.resize(inp_numel);
8986
QParams qp = extract_qparams(
90-
inp_scale, inp_zero_point, inp_quant_min, inp_quant_max, inp_axis, inp);
87+
inp_scale, inp_zero_point, inp_quant_min, inp_quant_max, inp);
9188
FUSED_QUANT_DTYPE_SWITCH(
9289
inp.scalar_type(),
9390
scalar_t,
@@ -104,12 +101,7 @@ Tensor& bmm_out(
104101
}
105102
other_buf.resize(other_numel);
106103
QParams qp = extract_qparams(
107-
other_scale,
108-
other_zero_point,
109-
other_quant_min,
110-
other_quant_max,
111-
other_axis,
112-
other);
104+
other_scale, other_zero_point, other_quant_min, other_quant_max, other);
113105
FUSED_QUANT_DTYPE_SWITCH(other.scalar_type(),
114106
scalar_t,
115107
dequantize_buffer(
@@ -126,7 +118,7 @@ Tensor& bmm_out(
126118
bmm_kernel(inp_float, other_float, result_float.data(), batch, M, K, N);
127119

128120
QParams qp = extract_qparams(
129-
out_scale, out_zero_point, out_quant_min, out_quant_max, out_axis, out);
121+
out_scale, out_zero_point, out_quant_min, out_quant_max, out);
130122
FUSED_QUANT_DTYPE_SWITCH(out.scalar_type(),
131123
scalar_t,
132124
quantize_buffer(

backends/cadence/fused_quant/op_bmm.h

Lines changed: 0 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -24,20 +24,17 @@ executorch::aten::Tensor& bmm_out(
2424
executorch::aten::ScalarType inp_dtype,
2525
int64_t inp_quant_min,
2626
int64_t inp_quant_max,
27-
executorch::aten::optional<int64_t> inp_axis,
2827
const executorch::aten::optional<executorch::aten::Tensor>& other_scale,
2928
const executorch::aten::optional<executorch::aten::Tensor>&
3029
other_zero_point,
3130
executorch::aten::ScalarType other_dtype,
3231
int64_t other_quant_min,
3332
int64_t other_quant_max,
34-
executorch::aten::optional<int64_t> other_axis,
3533
const executorch::aten::optional<executorch::aten::Tensor>& out_scale,
3634
const executorch::aten::optional<executorch::aten::Tensor>& out_zero_point,
3735
executorch::aten::ScalarType out_dtype,
3836
int64_t out_quant_min,
3937
int64_t out_quant_max,
40-
executorch::aten::optional<int64_t> out_axis,
4138
executorch::aten::Tensor& out);
4239

4340
} // namespace native

backends/cadence/fused_quant/op_hardswish.cpp

Lines changed: 2 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -40,13 +40,11 @@ Tensor& hardswish_out(
4040
ScalarType inp_dtype,
4141
int64_t inp_quant_min,
4242
int64_t inp_quant_max,
43-
optional<int64_t> inp_axis,
4443
const optional<Tensor>& out_scale,
4544
const optional<Tensor>& out_zero_point,
4645
ScalarType out_dtype,
4746
int64_t out_quant_min,
4847
int64_t out_quant_max,
49-
optional<int64_t> out_axis,
5048
Tensor& out) {
5149
int64_t numel = inp.numel();
5250

@@ -60,7 +58,7 @@ Tensor& hardswish_out(
6058
}
6159
inp_buf.resize(numel);
6260
QParams qp = extract_qparams(
63-
inp_scale, inp_zero_point, inp_quant_min, inp_quant_max, inp_axis, inp);
61+
inp_scale, inp_zero_point, inp_quant_min, inp_quant_max, inp);
6462
FUSED_QUANT_DTYPE_SWITCH(
6563
inp.scalar_type(),
6664
scalar_t,
@@ -74,7 +72,7 @@ Tensor& hardswish_out(
7472
hardswish_kernel(inp_float, result_float.data(), numel);
7573

7674
QParams qp = extract_qparams(
77-
out_scale, out_zero_point, out_quant_min, out_quant_max, out_axis, out);
75+
out_scale, out_zero_point, out_quant_min, out_quant_max, out);
7876
FUSED_QUANT_DTYPE_SWITCH(
7977
out.scalar_type(),
8078
scalar_t,

backends/cadence/fused_quant/op_hardswish.h

Lines changed: 0 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -23,13 +23,11 @@ executorch::aten::Tensor& hardswish_out(
2323
executorch::aten::ScalarType inp_dtype,
2424
int64_t inp_quant_min,
2525
int64_t inp_quant_max,
26-
executorch::aten::optional<int64_t> inp_axis,
2726
const executorch::aten::optional<executorch::aten::Tensor>& out_scale,
2827
const executorch::aten::optional<executorch::aten::Tensor>& out_zero_point,
2928
executorch::aten::ScalarType out_dtype,
3029
int64_t out_quant_min,
3130
int64_t out_quant_max,
32-
executorch::aten::optional<int64_t> out_axis,
3331
executorch::aten::Tensor& out);
3432

3533
} // namespace native

backends/cadence/fused_quant/op_mul.cpp

Lines changed: 3 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -42,19 +42,16 @@ Tensor& mul_out(
4242
ScalarType inp_dtype,
4343
int64_t inp_quant_min,
4444
int64_t inp_quant_max,
45-
optional<int64_t> inp_axis,
4645
const optional<Tensor>& other_scale,
4746
const optional<Tensor>& other_zero_point,
4847
ScalarType other_dtype,
4948
int64_t other_quant_min,
5049
int64_t other_quant_max,
51-
optional<int64_t> other_axis,
5250
const optional<Tensor>& out_scale,
5351
const optional<Tensor>& out_zero_point,
5452
ScalarType out_dtype,
5553
int64_t out_quant_min,
5654
int64_t out_quant_max,
57-
optional<int64_t> out_axis,
5855
Tensor& out) {
5956
(void)ctx;
6057
(void)inp_dtype;
@@ -74,7 +71,7 @@ Tensor& mul_out(
7471
}
7572
inp_buf.resize(numel);
7673
QParams qp = extract_qparams(
77-
inp_scale, inp_zero_point, inp_quant_min, inp_quant_max, inp_axis, inp);
74+
inp_scale, inp_zero_point, inp_quant_min, inp_quant_max, inp);
7875
FUSED_QUANT_DTYPE_SWITCH(
7976
inp.scalar_type(),
8077
scalar_t,
@@ -90,12 +87,7 @@ Tensor& mul_out(
9087
}
9188
other_buf.resize(numel);
9289
QParams qp = extract_qparams(
93-
other_scale,
94-
other_zero_point,
95-
other_quant_min,
96-
other_quant_max,
97-
other_axis,
98-
other);
90+
other_scale, other_zero_point, other_quant_min, other_quant_max, other);
9991
FUSED_QUANT_DTYPE_SWITCH(
10092
other.scalar_type(),
10193
scalar_t,
@@ -109,7 +101,7 @@ Tensor& mul_out(
109101
mul_kernel(inp_float, other_float, result_float.data(), numel);
110102

111103
QParams qp = extract_qparams(
112-
out_scale, out_zero_point, out_quant_min, out_quant_max, out_axis, out);
104+
out_scale, out_zero_point, out_quant_min, out_quant_max, out);
113105
FUSED_QUANT_DTYPE_SWITCH(
114106
out.scalar_type(),
115107
scalar_t,

backends/cadence/fused_quant/op_mul.h

Lines changed: 0 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -24,20 +24,17 @@ executorch::aten::Tensor& mul_out(
2424
executorch::aten::ScalarType inp_dtype,
2525
int64_t inp_quant_min,
2626
int64_t inp_quant_max,
27-
executorch::aten::optional<int64_t> inp_axis,
2827
const executorch::aten::optional<executorch::aten::Tensor>& other_scale,
2928
const executorch::aten::optional<executorch::aten::Tensor>&
3029
other_zero_point,
3130
executorch::aten::ScalarType other_dtype,
3231
int64_t other_quant_min,
3332
int64_t other_quant_max,
34-
executorch::aten::optional<int64_t> other_axis,
3533
const executorch::aten::optional<executorch::aten::Tensor>& out_scale,
3634
const executorch::aten::optional<executorch::aten::Tensor>& out_zero_point,
3735
executorch::aten::ScalarType out_dtype,
3836
int64_t out_quant_min,
3937
int64_t out_quant_max,
40-
executorch::aten::optional<int64_t> out_axis,
4138
executorch::aten::Tensor& out);
4239

4340
} // namespace native

backends/cadence/fused_quant/op_relu.cpp

Lines changed: 2 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -39,13 +39,11 @@ Tensor& relu_out(
3939
ScalarType inp_dtype,
4040
int64_t inp_quant_min,
4141
int64_t inp_quant_max,
42-
optional<int64_t> inp_axis,
4342
const optional<Tensor>& out_scale,
4443
const optional<Tensor>& out_zero_point,
4544
ScalarType out_dtype,
4645
int64_t out_quant_min,
4746
int64_t out_quant_max,
48-
optional<int64_t> out_axis,
4947
Tensor& out) {
5048
int64_t numel = inp.numel();
5149

@@ -59,7 +57,7 @@ Tensor& relu_out(
5957
}
6058
inp_buf.resize(numel);
6159
QParams qp = extract_qparams(
62-
inp_scale, inp_zero_point, inp_quant_min, inp_quant_max, inp_axis, inp);
60+
inp_scale, inp_zero_point, inp_quant_min, inp_quant_max, inp);
6361
FUSED_QUANT_DTYPE_SWITCH(
6462
inp.scalar_type(),
6563
scalar_t,
@@ -73,7 +71,7 @@ Tensor& relu_out(
7371
relu_kernel(inp_float, result_float.data(), numel);
7472

7573
QParams qp = extract_qparams(
76-
out_scale, out_zero_point, out_quant_min, out_quant_max, out_axis, out);
74+
out_scale, out_zero_point, out_quant_min, out_quant_max, out);
7775
FUSED_QUANT_DTYPE_SWITCH(
7876
out.scalar_type(),
7977
scalar_t,

backends/cadence/fused_quant/op_relu.h

Lines changed: 0 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -23,13 +23,11 @@ executorch::aten::Tensor& relu_out(
2323
executorch::aten::ScalarType inp_dtype,
2424
int64_t inp_quant_min,
2525
int64_t inp_quant_max,
26-
executorch::aten::optional<int64_t> inp_axis,
2726
const executorch::aten::optional<executorch::aten::Tensor>& out_scale,
2827
const executorch::aten::optional<executorch::aten::Tensor>& out_zero_point,
2928
executorch::aten::ScalarType out_dtype,
3029
int64_t out_quant_min,
3130
int64_t out_quant_max,
32-
executorch::aten::optional<int64_t> out_axis,
3331
executorch::aten::Tensor& out);
3432

3533
} // namespace native

0 commit comments

Comments
 (0)