Skip to content

Commit 584ef68

Browse files
rascani and claude authored
Cortex-M: remove extractScalarToInt helpers, use static_cast directly (#17543)
### Summary

Since #16322 changed all Scalar parameters to int64_t, the extractScalarToInt32 and extractScalarToInt helpers in cortex_m_ops_common.h are unnecessary indirection — they called Scalar::to<int64_t>() then cast, but the values are already int64_t. Replace all call sites with direct static_cast and remove the helpers, the unused Scalar type alias, and the scalar_utils.h include.

Also update the quantized_mul Python schema in operators.py to use int instead of Scalar, matching the C++ kernel and operators.yaml.

Additionally remove a duplicate kernel_includes.h include from cortex_m_ops_common.h.

### Test plan

```
pytest backends/cortex_m/test/
```

cc @digantdesai @SS-JIA @freddan80 @per @zingo @oscarandersson8218 @mansnils @Sebastian-Larsson @robell @psiddh @AdrianLundell

Co-authored-by: Claude <noreply@anthropic.com>
1 parent 2cb1ef5 commit 584ef68

4 files changed

Lines changed: 20 additions & 39 deletions

File tree

backends/cortex_m/ops/cortex_m_ops_common.h

Lines changed: 0 additions & 19 deletions
Original file line number | Diff line number | Diff line change
@@ -12,10 +12,8 @@
1212
#include <executorch/kernels/portable/cpu/util/broadcast_util.h>
1313
#include <executorch/runtime/kernel/kernel_includes.h>
1414

15-
#include <executorch/kernels/portable/cpu/scalar_utils.h>
1615
#include <executorch/kernels/portable/cpu/util/elementwise_util.h>
1716
#include <executorch/kernels/portable/cpu/util/kernel_ops_util.h>
18-
#include <executorch/runtime/kernel/kernel_includes.h>
1917
#include <executorch/runtime/platform/assert.h>
2018

2119
#include <limits>
@@ -27,7 +25,6 @@ extern "C" {
2725

2826
using Tensor = torch::executor::Tensor;
2927
using ScalarType = executorch::aten::ScalarType;
30-
using Scalar = torch::executor::Scalar;
3128
using Error = executorch::runtime::Error;
3229
using Int64ArrayRef = executorch::aten::ArrayRef<int64_t>;
3330
using KernelRuntimeContext = torch::executor::KernelRuntimeContext;
@@ -391,19 +388,3 @@ inline Error resize_to_broadcast_target_size(
391388
return executorch::runtime::resize_tensor(
392389
output, {expected_output_size, expected_output_dim});
393390
}
394-
395-
/**
396-
* Convert Scalar to CMSIS-NN int32 format
397-
* For multipliers, zero_points, etc. from quantize_multiplier_aot
398-
*/
399-
inline int32_t extractScalarToInt32(const Scalar& scalar_value) {
400-
return static_cast<int32_t>(scalar_value.to<int64_t>());
401-
}
402-
403-
/**
404-
* Convert Scalar to CMSIS-NN int format
405-
* For shift values from quantize_multiplier_aot
406-
*/
407-
inline int extractScalarToInt(const Scalar& scalar_value) {
408-
return static_cast<int>(scalar_value.to<int64_t>());
409-
}

backends/cortex_m/ops/op_quantized_add.cpp

Lines changed: 9 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -60,15 +60,15 @@ Tensor& quantized_add_out(
6060
"quantized_add_out: input1_int8.sizes() = %zu",
6161
input1_int8.sizes().size());
6262

63-
int32_t zp1 = extractScalarToInt32(input1_zero_point);
64-
int32_t input1_mult = extractScalarToInt32(input1_multiplier);
65-
int input1_shift_val = extractScalarToInt(input1_shift);
66-
int32_t zp2 = extractScalarToInt32(input2_zero_point);
67-
int32_t input2_mult = extractScalarToInt32(input2_multiplier);
68-
int input2_shift_val = extractScalarToInt(input2_shift);
69-
int32_t out_zp = extractScalarToInt32(output_zero_point);
70-
int32_t output_mult = extractScalarToInt32(output_multiplier);
71-
int output_shift_val = extractScalarToInt(output_shift);
63+
int32_t zp1 = static_cast<int32_t>(input1_zero_point);
64+
int32_t input1_mult = static_cast<int32_t>(input1_multiplier);
65+
int input1_shift_val = static_cast<int>(input1_shift);
66+
int32_t zp2 = static_cast<int32_t>(input2_zero_point);
67+
int32_t input2_mult = static_cast<int32_t>(input2_multiplier);
68+
int input2_shift_val = static_cast<int>(input2_shift);
69+
int32_t out_zp = static_cast<int32_t>(output_zero_point);
70+
int32_t output_mult = static_cast<int32_t>(output_multiplier);
71+
int output_shift_val = static_cast<int>(output_shift);
7272
int8_t* input1_ptr = input1_int8.data_ptr<int8_t>();
7373
int8_t* input2_ptr = input2_int8.data_ptr<int8_t>();
7474

backends/cortex_m/ops/op_quantized_mul.cpp

Lines changed: 5 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -61,11 +61,11 @@ Tensor& quantized_mul_out(
6161
// Extract quantization parameters
6262
int8_t* input1_ptr = input1_int8.data_ptr<int8_t>();
6363
int8_t* input2_ptr = input2_int8.data_ptr<int8_t>();
64-
int32_t zp1 = extractScalarToInt32(input1_zero_point);
65-
int32_t zp2 = extractScalarToInt32(input2_zero_point);
66-
const int32_t out_zp = extractScalarToInt32(output_zero_point);
67-
const int32_t output_mult = extractScalarToInt32(output_multiplier);
68-
const int32_t output_shift_val = extractScalarToInt32(output_shift);
64+
int32_t zp1 = static_cast<int32_t>(input1_zero_point);
65+
int32_t zp2 = static_cast<int32_t>(input2_zero_point);
66+
const int32_t out_zp = static_cast<int32_t>(output_zero_point);
67+
const int32_t output_mult = static_cast<int32_t>(output_multiplier);
68+
const int32_t output_shift_val = static_cast<int32_t>(output_shift);
6969

7070
int32_t muls_per_loop = 0;
7171

backends/cortex_m/ops/operators.py

Lines changed: 6 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -195,15 +195,15 @@ def quantized_add_impl(
195195
# ===================================================================
196196
lib.define(
197197
"quantized_mul("
198-
"Tensor self, Scalar self_zero_point, "
199-
"Tensor other, Scalar other_zero_point, "
200-
"Scalar output_zero_point, Scalar output_multiplier, Scalar output_shift) -> Tensor"
198+
"Tensor self, int self_zero_point, "
199+
"Tensor other, int other_zero_point, "
200+
"int output_zero_point, int output_multiplier, int output_shift) -> Tensor"
201201
)
202202
lib.define(
203203
"quantized_mul.out("
204-
"Tensor self, Scalar self_zero_point, "
205-
"Tensor other, Scalar other_zero_point, "
206-
"Scalar output_zero_point, Scalar output_multiplier, Scalar output_shift, "
204+
"Tensor self, int self_zero_point, "
205+
"Tensor other, int other_zero_point, "
206+
"int output_zero_point, int output_multiplier, int output_shift, "
207207
"*, Tensor(a!) out) -> Tensor(a!)"
208208
)
209209

0 commit comments

Comments
 (0)