@@ -53,19 +53,16 @@ Tensor& bmm_out(
5353 ScalarType inp_dtype,
5454 int64_t inp_quant_min,
5555 int64_t inp_quant_max,
56- optional<int64_t > inp_axis,
5756 const optional<Tensor>& other_scale,
5857 const optional<Tensor>& other_zero_point,
5958 ScalarType other_dtype,
6059 int64_t other_quant_min,
6160 int64_t other_quant_max,
62- optional<int64_t > other_axis,
6361 const optional<Tensor>& out_scale,
6462 const optional<Tensor>& out_zero_point,
6563 ScalarType out_dtype,
6664 int64_t out_quant_min,
6765 int64_t out_quant_max,
68- optional<int64_t > out_axis,
6966 Tensor& out) {
7067 int64_t batch = inp.size (0 );
7168 int64_t M = inp.size (1 );
@@ -87,7 +84,7 @@ Tensor& bmm_out(
8784 }
8885 inp_buf.resize (inp_numel);
8986 QParams qp = extract_qparams (
90- inp_scale, inp_zero_point, inp_quant_min, inp_quant_max, inp_axis, inp);
87+ inp_scale, inp_zero_point, inp_quant_min, inp_quant_max, inp);
9188 FUSED_QUANT_DTYPE_SWITCH (
9289 inp.scalar_type (),
9390 scalar_t ,
@@ -104,12 +101,7 @@ Tensor& bmm_out(
104101 }
105102 other_buf.resize (other_numel);
106103 QParams qp = extract_qparams (
107- other_scale,
108- other_zero_point,
109- other_quant_min,
110- other_quant_max,
111- other_axis,
112- other);
104+ other_scale, other_zero_point, other_quant_min, other_quant_max, other);
113105 FUSED_QUANT_DTYPE_SWITCH (other.scalar_type (),
114106 scalar_t ,
115107 dequantize_buffer (
@@ -126,7 +118,7 @@ Tensor& bmm_out(
126118 bmm_kernel (inp_float, other_float, result_float.data (), batch, M, K, N);
127119
128120 QParams qp = extract_qparams (
129- out_scale, out_zero_point, out_quant_min, out_quant_max, out_axis, out);
121+ out_scale, out_zero_point, out_quant_min, out_quant_max, out);
130122 FUSED_QUANT_DTYPE_SWITCH (out.scalar_type (),
131123 scalar_t ,
132124 quantize_buffer (
0 commit comments