diff --git a/backends/cadence/aot/ops_registrations.py b/backends/cadence/aot/ops_registrations.py
index 131c85c9ab1..d88397626c8 100644
--- a/backends/cadence/aot/ops_registrations.py
+++ b/backends/cadence/aot/ops_registrations.py
@@ -772,6 +772,7 @@ def quantize_per_tensor_meta(
     quant_max: int,
     dtype: torch.dtype,
 ) -> torch.Tensor:
+    torch._check(input.dtype == torch.float32, lambda: "expected float32")
     return input.new_empty(input.size(), dtype=dtype)


@@ -784,6 +785,7 @@ def quantize_per_tensor_asym8s_meta(
     quant_max: int,
     dtype: torch.dtype,
 ) -> torch.Tensor:
+    torch._check(input.dtype == torch.float32, lambda: "expected float32")
     return input.new_empty(input.size(), dtype=dtype)


@@ -796,6 +798,7 @@ def quantize_per_tensor_asym8u_meta(
     quant_max: int,
     dtype: torch.dtype,
 ) -> torch.Tensor:
+    torch._check(input.dtype == torch.float32, lambda: "expected float32")
     return input.new_empty(input.size(), dtype=dtype)


@@ -808,6 +811,7 @@ def quantize_per_tensor_asym16s_meta(
     quant_max: int,
     dtype: torch.dtype,
 ) -> torch.Tensor:
+    torch._check(input.dtype == torch.float32, lambda: "expected float32")
     return input.new_empty(input.size(), dtype=dtype)


@@ -820,6 +824,7 @@ def quantize_per_tensor_asym16u_meta(
     quant_max: int,
     dtype: torch.dtype,
 ) -> torch.Tensor:
+    torch._check(input.dtype == torch.float32, lambda: "expected float32")
     return input.new_empty(input.size(), dtype=dtype)


@@ -832,6 +837,7 @@ def quantize_per_tensor_asym32s_meta(
     quant_max: int,
     dtype: torch.dtype,
 ) -> torch.Tensor:
+    torch._check(input.dtype == torch.float32, lambda: "expected float32")
     return input.new_empty(input.size(), dtype=dtype)


@@ -856,6 +862,9 @@ def dequantize_per_tensor_asym8s_meta(
     quant_max: int,
     dtype: torch.dtype,
 ) -> torch.Tensor:
+    torch._check(
+        input.dtype in (torch.int8, torch.uint8), lambda: "expected 8-bit dtype"
+    )
     return input.new_empty(input.size(), dtype=torch.float)


@@ -868,6 +877,7 @@ def dequantize_per_tensor_asym8u_meta(
     quant_max: int,
     dtype: torch.dtype,
 ) -> torch.Tensor:
+    torch._check(input.dtype in (torch.int8, torch.uint8), lambda: "expected 8-bit dtype")
     return input.new_empty(input.size(), dtype=torch.float)


@@ -918,6 +928,7 @@ def quantized_add_meta(
     out_scale: float,
     out_zero_point: int,
 ) -> torch.Tensor:
+    torch._check(X.dtype == Y.dtype, lambda: "expected same dtype")
     # Determine output shape by broadcasting X and Y
     out_size = torch.broadcast_shapes(X.size(), Y.size())

@@ -951,6 +962,7 @@ def quantized_add_per_tensor_meta(
     out_scale: float,
     out_zero_point: int,
 ) -> torch.Tensor:
+    torch._check(X.dtype == Y.dtype, lambda: "expected same dtype")
     out_size = torch.broadcast_shapes(X.size(), Y.size())
     return X.new_empty(out_size, dtype=X.dtype)

@@ -967,6 +979,9 @@ def quantized_add_asym8sxasym8s_asym8s_per_tensor_meta(
     out_scale: float,
     out_zero_point: int,
 ) -> torch.Tensor:
+    torch._check(X.dtype in (torch.int8, torch.uint8), lambda: "expected 8-bit dtype")
+    torch._check(Y.dtype in (torch.int8, torch.uint8), lambda: "expected 8-bit dtype")
+    torch._check(X.dtype == Y.dtype, lambda: "expected same dtype")
     out_size = torch.broadcast_shapes(X.size(), Y.size())
     return X.new_empty(out_size, dtype=X.dtype)

@@ -982,6 +997,9 @@ def quantized_add_asym8uxasym8u_asym8u_per_tensor_meta(
     out_scale: float,
     out_zero_point: int,
 ) -> torch.Tensor:
+    torch._check(X.dtype in (torch.int8, torch.uint8), lambda: "expected 8-bit dtype")
+    torch._check(Y.dtype in (torch.int8, torch.uint8), lambda: "expected 8-bit dtype")
+    torch._check(X.dtype == Y.dtype, lambda: "expected same dtype")
     out_size = torch.broadcast_shapes(X.size(), Y.size())
     return X.new_empty(out_size, dtype=X.dtype)
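
The hunks above all follow one pattern: a torch._check guard at the top of a
meta (shape-propagation) kernel. The message argument is a lambda so the
string is only built when the check fails, and the check fires during
export/tracing, where these meta kernels run on fake tensors. A minimal
sketch of that behavior (the quantize_meta name and shapes here are
illustrative stand-ins, not the registered ops):

    import torch
    from torch._subclasses.fake_tensor import FakeTensorMode

    def quantize_meta(input: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
        # The lambda defers message construction until the check actually fails.
        torch._check(input.dtype == torch.float32, lambda: "expected float32")
        return input.new_empty(input.size(), dtype=dtype)

    with FakeTensorMode():
        x = torch.empty(4, 4, dtype=torch.int64)
        quantize_meta(x, torch.int8)  # raises RuntimeError: expected float32
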
@@ -998,6 +1016,8 @@ def quantized_linear_meta(
     out_zero_point: int,
     offset: Optional[torch.Tensor],
 ) -> torch.Tensor:
+    torch._check(bias.dtype == torch.int32, lambda: "expected int32")
+    torch._check(weight.dim() == 2, lambda: "expected 2D tensor")
     # src comes in shape [leading_dims, in_dim]
     # weight comes in shape [out_dim, in_dim]
     # output comes in empty with shape [leading_dims, out_dim]
@@ -1020,6 +1040,8 @@ def quantized_linear_per_tensor_meta(
     out_zero_point: torch.SymInt,
     offset: Optional[torch.Tensor],
 ) -> torch.Tensor:
+    torch._check(bias.dtype == torch.int32, lambda: "expected int32")
+    torch._check(weight.dim() == 2, lambda: "expected 2D tensor")
     # src comes in shape [leading_dims, in_dim]
     # weight comes in shape [out_dim, in_dim]
     # output comes in empty with shape [leading_dims, out_dim]
@@ -1042,6 +1064,12 @@ def quantized_linear_asym8sxasym8s_asym8s_per_tensor_meta(
     out_zero_point: int,
     offset: Optional[torch.Tensor],
 ) -> torch.Tensor:
+    torch._check(src.dtype in (torch.int8, torch.uint8), lambda: "expected 8-bit dtype")
+    torch._check(
+        weight.dtype in (torch.int8, torch.uint8), lambda: "expected 8-bit dtype"
+    )
+    torch._check(bias.dtype == torch.int32, lambda: "expected int32")
+    torch._check(weight.dim() == 2, lambda: "expected 2D tensor")
     # src comes in shape [leading_dims, in_dim]
     # weight comes in shape [out_dim, in_dim]
     # output comes in empty with shape [leading_dims, out_dim]
@@ -1064,6 +1092,12 @@ def quantized_linear_asym8uxasym8u_asym8u_per_tensor_meta(
     out_zero_point: int,
     offset: Optional[torch.Tensor],
 ) -> torch.Tensor:
+    torch._check(src.dtype in (torch.int8, torch.uint8), lambda: "expected 8-bit dtype")
+    torch._check(
+        weight.dtype in (torch.int8, torch.uint8), lambda: "expected 8-bit dtype"
+    )
+    torch._check(bias.dtype == torch.int32, lambda: "expected int32")
+    torch._check(weight.dim() == 2, lambda: "expected 2D tensor")
     # src comes in shape [leading_dims, in_dim]
     # weight comes in shape [out_dim, in_dim]
     # output comes in empty with shape [leading_dims, out_dim]
@@ -1091,6 +1125,7 @@ def quantized_conv2d_nhwc_meta(
     out_multiplier: torch.Tensor,
     out_shift: torch.Tensor,
 ) -> torch.Tensor:
+    torch._check(bias.dtype == torch.int32, lambda: "expected int32")
     in_size = input.shape
     # Assert that the input tensor has at least 3 dimensions, and at most 6
     assert len(in_size) > 2
@@ -1143,6 +1178,7 @@ def quantized_conv1d_ncl_meta(
     out_multiplier: torch.Tensor,
     out_shift: torch.Tensor,
 ) -> torch.Tensor:
+    torch._check(bias.dtype == torch.int32, lambda: "expected int32")
     # NCL format: input is [N, C, L], weight is [OC, IC/groups, K]
     out_channels, _, kernel_size = weight.shape

@@ -1180,6 +1216,7 @@ def quantized_conv1d_ncl_per_tensor_meta(
     out_multiplier: int,
     out_shift: int,
 ) -> torch.Tensor:
+    torch._check(bias.dtype == torch.int32, lambda: "expected int32")
     # NCL format: input is [N, C, L], weight is [OC, IC/groups, K]
     out_channels, _, kernel_size = weight.shape

@@ -1217,6 +1254,7 @@ def quantized_conv1d_nlc_meta(
     out_multiplier: torch.Tensor,
     out_shift: torch.Tensor,
 ) -> torch.Tensor:
+    torch._check(bias.dtype == torch.int32, lambda: "expected int32")
     # NLC format: input is [N, L, C], weight is [OC, K, IC/groups]
     out_channels, kernel_size, _ = weight.shape

@@ -1254,6 +1292,7 @@ def quantized_conv1d_nlc_per_tensor_meta(
     out_multiplier: int,
     out_shift: int,
 ) -> torch.Tensor:
+    torch._check(bias.dtype == torch.int32, lambda: "expected int32")
     # NLC format: input is [N, L, C], weight is [OC, K, IC/groups]
     out_channels, kernel_size, _ = weight.shape
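
The linear-style metas above all share the shape rule spelled out in their
comments: keep src's leading dims and replace the last dim with weight's
out_dim. A standalone sketch of that rule (the helper name is illustrative,
not from the patch):

    import torch

    def linear_out_shape(src: torch.Tensor, weight: torch.Tensor) -> list:
        # src: [*leading_dims, in_dim]; weight: [out_dim, in_dim]
        torch._check(weight.dim() == 2, lambda: "expected 2D tensor")
        return list(src.shape[:-1]) + [weight.shape[0]]

    src = torch.empty(2, 7, 16)
    weight = torch.empty(32, 16)
    assert linear_out_shape(src, weight) == [2, 7, 32]
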
@@ -1291,6 +1330,7 @@ def quantized_depthwise_conv1d_ncl_per_tensor_meta(
     out_multiplier: int,
     out_shift: int,
 ) -> torch.Tensor:
+    torch._check(bias.dtype == torch.int32, lambda: "expected int32")
     # NCL format: input is [N, C, L], weight is [OC, IC/groups, K]
     out_channels, _, kernel_size = weight.shape

@@ -1327,6 +1367,7 @@ def quantized_depthwise_conv1d_nlc_per_tensor_meta(
     out_multiplier: int,
     out_shift: int,
 ) -> torch.Tensor:
+    torch._check(bias.dtype == torch.int32, lambda: "expected int32")
     # NLC format: input is [N, L, C], weight is [OC, K, IC/groups]
     out_channels, kernel_size, _ = weight.shape

@@ -1363,6 +1404,7 @@ def quantized_conv2d_nchw_meta(
     out_multiplier: torch.Tensor,
     out_shift: torch.Tensor,
 ) -> torch.Tensor:
+    torch._check(bias.dtype == torch.int32, lambda: "expected int32")
     out_channels, _, *kernel_size = weight.shape
     in_size = input.shape

@@ -1407,6 +1449,7 @@ def quantized_conv2d_nchw_per_tensor_meta(
     out_multiplier: int,
     out_shift: int,
 ) -> torch.Tensor:
+    torch._check(bias.dtype == torch.int32, lambda: "expected int32")
     out_channels, _, *kernel_size = weight.shape
     in_size = input.shape

@@ -1452,6 +1495,7 @@ def quantized_conv2d_nhwc_per_tensor_meta(
     out_shift: int,
     offset: Optional[torch.Tensor] = None,
 ) -> torch.Tensor:
+    torch._check(bias.dtype == torch.int32, lambda: "expected int32")
     in_size = input.shape
     # Assert that the input tensor has at least 3 dimensions, and at most 6
     assert len(in_size) > 2
@@ -2128,6 +2172,7 @@ def quantized_conv2d_depthwise_nhwc_meta(
     out_multiplier: int,
     out_shift: int,
 ) -> torch.Tensor:
+    torch._check(bias.dtype == torch.int32, lambda: "expected int32")
     in_size = input.shape
     assert len(in_size) > 2
     assert len(in_size) < 6
@@ -2166,6 +2211,8 @@ def quantized_layer_norm_meta(
     output_scale: float,
     output_zero_point: int,
 ) -> torch.Tensor:
+    torch._check(weight.dtype == torch.float32, lambda: "expected float32")
+    torch._check(bias.dtype == torch.float32, lambda: "expected float32")
     return input.new_empty(input.size(), dtype=input.dtype)


@@ -2181,6 +2228,8 @@ def quantized_layer_norm_per_tensor_meta(
     output_scale: float,
     output_zero_point: int,
 ) -> torch.Tensor:
+    torch._check(weight.dtype == torch.float32, lambda: "expected float32")
+    torch._check(bias.dtype == torch.float32, lambda: "expected float32")
     return input.new_empty(input.size(), dtype=input.dtype)


@@ -2207,6 +2256,9 @@ def quantized_matmul_meta(
     out_zero_point: int,
     transposed: bool = False,
 ) -> torch.Tensor:
+    torch._check(X.dtype == Y.dtype, lambda: "expected same dtype")
+    torch._check(X.dim() >= 2, lambda: "expected at least 2D tensor")
+    torch._check(Y.dim() >= 2, lambda: "expected at least 2D tensor")
     X_size = list(X.size())
     Y_size = list(Y.size())

@@ -2250,6 +2302,11 @@ def quantized_matmul_asym8sxasym8s_asym8s_meta(
     out_zero_point: int,
     transposed: bool = False,
 ) -> torch.Tensor:
+    torch._check(X.dtype in (torch.int8, torch.uint8), lambda: "expected 8-bit dtype")
+    torch._check(Y.dtype in (torch.int8, torch.uint8), lambda: "expected 8-bit dtype")
+    torch._check(X.dtype == Y.dtype, lambda: "expected same dtype")
+    torch._check(X.dim() >= 2, lambda: "expected at least 2D tensor")
+    torch._check(Y.dim() >= 2, lambda: "expected at least 2D tensor")
     X_size = list(X.size())
     Y_size = list(Y.size())

@@ -2293,6 +2350,11 @@ def quantized_matmul_asym8uxasym8u_asym8u_meta(
     out_zero_point: int,
     transposed: bool = False,
 ) -> torch.Tensor:
+    torch._check(X.dtype in (torch.int8, torch.uint8), lambda: "expected 8-bit dtype")
+    torch._check(Y.dtype in (torch.int8, torch.uint8), lambda: "expected 8-bit dtype")
+    torch._check(X.dtype == Y.dtype, lambda: "expected same dtype")
+    torch._check(X.dim() >= 2, lambda: "expected at least 2D tensor")
+    torch._check(Y.dim() >= 2, lambda: "expected at least 2D tensor")
     X_size = list(X.size())
     Y_size = list(Y.size())
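
The matmul metas validate that both operands are at least rank 2 before any
shape arithmetic. A simplified sketch of the output-shape logic those checks
guard (the real metas work on X_size/Y_size lists and handle batch dims;
this helper name is illustrative):

    import torch

    def matmul_out_shape(X, Y, transposed=False):
        torch._check(X.dtype == Y.dtype, lambda: "expected same dtype")
        torch._check(X.dim() >= 2, lambda: "expected at least 2D tensor")
        torch._check(Y.dim() >= 2, lambda: "expected at least 2D tensor")
        # With transposed=True, Y is supplied as [..., N, K] instead of [..., K, N].
        n = Y.shape[-2] if transposed else Y.shape[-1]
        return list(X.shape[:-1]) + [n]

    assert matmul_out_shape(torch.empty(1, 4, 8), torch.empty(1, 8, 5)) == [1, 4, 5]
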
@@ -2334,6 +2396,7 @@ def im2row_meta(
     in_zero_point: torch.Tensor,
     channel_last: bool = False,
 ) -> torch.Tensor:
+    torch._check(3 <= input.dim() <= 4, lambda: "expected 3-4D tensor")
     output_size = get_im2row_output_size(
         input, kernel_size, dilation, padding, stride, channel_last
     )
@@ -2350,6 +2413,7 @@ def im2row_per_tensor_meta(
     in_zero_point: int,
     channel_last: bool = False,
 ) -> torch.Tensor:
+    torch._check(3 <= input.dim() <= 4, lambda: "expected 3-4D tensor")
     output_size = get_im2row_output_size(
         input, kernel_size, dilation, padding, stride, channel_last
     )
@@ -2427,6 +2491,9 @@ def quantized_relu_asym8s_asym8s_per_tensor_meta(
     out_multiplier: int,
     out_shift: int,
 ) -> torch.Tensor:
+    torch._check(
+        input.dtype in (torch.int8, torch.uint8), lambda: "expected 8-bit dtype"
+    )
     return input.new_empty(input.size(), dtype=input.dtype)


@@ -2438,6 +2505,9 @@ def quantized_relu_asym8u_asym8u_per_tensor_meta(
     out_multiplier: int,
     out_shift: int,
 ) -> torch.Tensor:
+    torch._check(
+        input.dtype in (torch.int8, torch.uint8), lambda: "expected 8-bit dtype"
+    )
     return input.new_empty(input.size(), dtype=input.dtype)


@@ -2529,6 +2599,8 @@ def fully_connected_meta(
     weight: torch.Tensor,
     bias: torch.Tensor,
 ) -> torch.Tensor:
+    torch._check(src.dtype == torch.float32, lambda: "expected float32")
+    torch._check(src.size(0) == 1, lambda: "expected batch size of 1")
     # src comes in shape [leading_dims, in_dim]
     # weight comes in shape [out_dim, in_dim]
     # output comes in empty with shape [leading_dims, out_dim]
@@ -2551,6 +2623,9 @@ def quantized_fully_connected_meta(
     out_zero_point: int,
     offset: Optional[torch.Tensor],
 ) -> torch.Tensor:
+    torch._check(bias.dtype == torch.int32, lambda: "expected int32")
+    torch._check(weight.dim() == 2, lambda: "expected 2D tensor")
+    torch._check(src.size(0) == 1, lambda: "expected batch size of 1")
     # src comes in shape [leading_dims, in_dim]
     # weight comes in shape [out_dim, in_dim]
     # output comes in empty with shape [leading_dims, out_dim]
@@ -2574,6 +2649,9 @@ def quantized_fully_connected_per_tensor_meta(
     out_zero_point: int,
     offset: Optional[torch.Tensor],
 ) -> torch.Tensor:
+    torch._check(bias.dtype == torch.int32, lambda: "expected int32")
+    torch._check(weight.dim() == 2, lambda: "expected 2D tensor")
+    torch._check(src.size(0) == 1, lambda: "expected batch size of 1")
     # src comes in shape [leading_dims, in_dim]
     # weight comes in shape [out_dim, in_dim]
     # output comes in empty with shape [leading_dims, out_dim]
@@ -2597,6 +2675,13 @@ def quantized_fully_connected_asym8sxasym8s_asym8s_per_tensor_meta(
     out_zero_point: int,
     offset: Optional[torch.Tensor],
 ) -> torch.Tensor:
+    torch._check(src.dtype in (torch.int8, torch.uint8), lambda: "expected 8-bit dtype")
+    torch._check(
+        weight.dtype in (torch.int8, torch.uint8), lambda: "expected 8-bit dtype"
+    )
+    torch._check(bias.dtype == torch.int32, lambda: "expected int32")
+    torch._check(weight.dim() == 2, lambda: "expected 2D tensor")
+    torch._check(src.size(0) == 1, lambda: "expected batch size of 1")
     # src comes in shape [leading_dims, in_dim]
     # weight comes in shape [out_dim, in_dim]
     # output comes in empty with shape [leading_dims, out_dim]
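
The fully-connected metas are stricter than the linear ones: besides the
int32 bias and 2D weight, they pin the batch dimension to 1. A quick,
self-contained illustration of those guards (the helper name is
illustrative, not from the patch):

    import torch

    def fc_meta_checks(src, weight, bias):
        torch._check(bias.dtype == torch.int32, lambda: "expected int32")
        torch._check(weight.dim() == 2, lambda: "expected 2D tensor")
        torch._check(src.size(0) == 1, lambda: "expected batch size of 1")

    w = torch.empty(32, 16, dtype=torch.int8)
    b = torch.empty(32, dtype=torch.int32)
    fc_meta_checks(torch.empty(1, 16, dtype=torch.int8), w, b)  # passes
    try:
        fc_meta_checks(torch.empty(2, 16, dtype=torch.int8), w, b)
    except RuntimeError as e:
        print(e)  # expected batch size of 1
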
@@ -2620,6 +2705,13 @@ def quantized_fully_connected_asym8uxasym8u_asym8u_per_tensor_meta(
     out_zero_point: int,
     offset: Optional[torch.Tensor],
 ) -> torch.Tensor:
+    torch._check(src.dtype in (torch.int8, torch.uint8), lambda: "expected 8-bit dtype")
+    torch._check(
+        weight.dtype in (torch.int8, torch.uint8), lambda: "expected 8-bit dtype"
+    )
+    torch._check(bias.dtype == torch.int32, lambda: "expected int32")
+    torch._check(weight.dim() == 2, lambda: "expected 2D tensor")
+    torch._check(src.size(0) == 1, lambda: "expected batch size of 1")
     # src comes in shape [leading_dims, in_dim]
     # weight comes in shape [out_dim, in_dim]
     # output comes in empty with shape [leading_dims, out_dim]
@@ -2641,6 +2733,7 @@ def conv1d_meta(
     dilation: Tuple[int],
     groups: int,
 ) -> torch.Tensor:
+    torch._check(input.dtype == torch.float32, lambda: "expected float32")
     # Validate tensor dimensions
     assert len(input.shape) == 3, f"Conv1d expects 3D input, got {len(input.shape)}D"
     assert len(weight.shape) == 3, f"Conv1d expects 3D weight, got {len(weight.shape)}D"
@@ -2687,6 +2780,7 @@ def conv2d_meta(
     dilation: Tuple[int],
     groups: int,
 ) -> torch.Tensor:
+    torch._check(input.dtype == torch.float32, lambda: "expected float32")
     assert (
         len(weight.shape) == 4
     ), f"Conv2d expects a 4D weight, got {len(weight.shape)}D"
@@ -2711,6 +2805,7 @@ def conv3d_meta(
     dilation: Tuple[int, int, int],
     groups: int,
 ) -> torch.Tensor:
+    torch._check(input.dtype == torch.float32, lambda: "expected float32")
     assert (
         len(weight.shape) == 5
     ), f"Conv3d expects a 5D weight, got {len(weight.shape)}D"
@@ -2879,6 +2974,7 @@ def avg_pool2d_meta(
     in_zero_point: Optional[torch.Tensor] = None,
     channel_last: bool = False,
 ) -> torch.Tensor:
+    torch._check(input.dim() == 4, lambda: "expected 4D tensor")
     # Use torch native meta kernels when operator semantics are similar
     return torch._meta_registrations.meta_avg_pool2d(
         input,
@@ -2978,6 +3074,8 @@ def rope_meta(
     cos_tensor: torch.Tensor,
     pos: Optional[torch.Tensor],
 ) -> torch.Tensor:
+    torch._check(sin_tensor.dtype == torch.float32, lambda: "expected float32")
+    torch._check(cos_tensor.dtype == torch.float32, lambda: "expected float32")
     input_shape = list(input.shape)
     assert (
         len(input_shape) in (4, 5) and input_shape[0] == 1
@@ -3005,6 +3103,8 @@ def rope_rotate_stacked_halves_meta(
     cos_tensor: torch.Tensor,
     pos: Optional[torch.Tensor],
 ) -> torch.Tensor:
+    torch._check(sin_tensor.dtype == torch.float32, lambda: "expected float32")
+    torch._check(cos_tensor.dtype == torch.float32, lambda: "expected float32")
     input_shape = list(input.shape)
     assert (
         len(input_shape) in (4, 5) and input_shape[0] == 1
@@ -3072,6 +3172,8 @@ def roi_align_box_processor_meta(
     sampling_ratio: int,
     aligned: bool,
 ) -> torch.Tensor:
+    torch._check(rois.dim() == 2, lambda: "expected 2D tensor")
+    torch._check(rois.size(1) == 5, lambda: "expected dim[1] == 5")
     return rois.new_empty((rois.shape[0], 80), dtype=torch.uint8)


@@ -3121,6 +3223,9 @@ def sdpa_bitwise_mask_gen_meta(
     mask: torch.Tensor,
     threshold: float,
 ) -> torch.Tensor:
+    torch._check(
+        mask.dtype in (torch.float32, torch.bool), lambda: "expected float32 or bool"
+    )
     # Expect mask to be a float/bool tensor with last dimension representing sequence length
     assert mask.dim() >= 1, "mask must have at least 1 dimension"
     mask_shape = list(mask.shape)
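
The two kernel changes below add a debug-build guard (ET_DCHECK_MSG) that
rejects divisor tensors containing zero before the vectorized division runs.
The predicate the new C++ helper computes is equivalent to this Python
one-liner:

    import torch

    def tensor_has_zero(t: torch.Tensor) -> bool:
        # Mirrors the C++ helper: true if any element equals zero exactly.
        return bool((t == 0).any())

    assert tensor_has_zero(torch.tensor([1.0, 0.0, 2.0]))
    assert not tensor_has_zero(torch.tensor([1, 2, 3]))
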
diff --git a/backends/cadence/fusion_g3/operators/op_div.cpp b/backends/cadence/fusion_g3/operators/op_div.cpp
index 62ebf303ebd..e3610a84baf 100644
--- a/backends/cadence/fusion_g3/operators/op_div.cpp
+++ b/backends/cadence/fusion_g3/operators/op_div.cpp
@@ -18,6 +18,7 @@
 #include
 #include
 #include
+#include <algorithm>

 using ::executorch::aten::Scalar;
 using ::executorch::aten::ScalarType;
@@ -34,6 +35,33 @@
 namespace native {

 namespace {
+[[maybe_unused]] bool tensor_has_zero(const Tensor& t) {
+  switch (t.scalar_type()) {
+    case ScalarType::Float:
+      return std::any_of(
+          t.const_data_ptr<float>(),
+          t.const_data_ptr<float>() + t.numel(),
+          [](float v) { return v == 0.0f; });
+    case ScalarType::Double:
+      return std::any_of(
+          t.const_data_ptr<double>(),
+          t.const_data_ptr<double>() + t.numel(),
+          [](double v) { return v == 0.0; });
+    case ScalarType::Int:
+      return std::any_of(
+          t.const_data_ptr<int32_t>(),
+          t.const_data_ptr<int32_t>() + t.numel(),
+          [](int32_t v) { return v == 0; });
+    case ScalarType::Long:
+      return std::any_of(
+          t.const_data_ptr<int64_t>(),
+          t.const_data_ptr<int64_t>() + t.numel(),
+          [](int64_t v) { return v == 0; });
+    default:
+      return false;
+  }
+}
+
 ScalarType get_common_type(ScalarType a_type, ScalarType b_type) {
   if (executorch::runtime::isFloatingType(a_type) &&
       executorch::runtime::isFloatingType(b_type)) {
@@ -53,6 +81,8 @@ Tensor& div_out(
     const Tensor& a,
     const Tensor& b,
     Tensor& out) {
+  ET_DCHECK_MSG(!tensor_has_zero(b), "divisor tensor contains zero");
+
 #ifdef OP_ARG_CHECK
   // Check Dim Order
   ET_KERNEL_CHECK(
@@ -228,6 +258,8 @@ Tensor& div_out_mode(
     const Tensor& b,
     optional<std::string_view> mode,
     Tensor& out) {
+  ET_DCHECK_MSG(!tensor_has_zero(b), "divisor tensor contains zero");
+
   if (!mode.has_value()) {
     return div_out(ctx, a, b, out);
   }
diff --git a/backends/cadence/hifi/operators/op_div.cpp b/backends/cadence/hifi/operators/op_div.cpp
index 057147a2a75..e389dc314ec 100644
--- a/backends/cadence/hifi/operators/op_div.cpp
+++ b/backends/cadence/hifi/operators/op_div.cpp
@@ -15,6 +15,7 @@
 #include
 #include
 #include
+#include <algorithm>
 #include

 using executorch::aten::RuntimeContext;
@@ -29,6 +30,33 @@
 namespace native {

 namespace {
+[[maybe_unused]] bool tensor_has_zero(const Tensor& t) {
+  switch (t.scalar_type()) {
+    case ScalarType::Float:
+      return std::any_of(
+          t.const_data_ptr<float>(),
+          t.const_data_ptr<float>() + t.numel(),
+          [](float v) { return v == 0.0f; });
+    case ScalarType::Double:
+      return std::any_of(
+          t.const_data_ptr<double>(),
+          t.const_data_ptr<double>() + t.numel(),
+          [](double v) { return v == 0.0; });
+    case ScalarType::Int:
+      return std::any_of(
+          t.const_data_ptr<int32_t>(),
+          t.const_data_ptr<int32_t>() + t.numel(),
+          [](int32_t v) { return v == 0; });
+    case ScalarType::Long:
+      return std::any_of(
+          t.const_data_ptr<int64_t>(),
+          t.const_data_ptr<int64_t>() + t.numel(),
+          [](int64_t v) { return v == 0; });
+    default:
+      return false;
+  }
+}
+
 ScalarType get_compute_type(ScalarType a_type, ScalarType b_type) {
   if (executorch::runtime::isFloatingType(a_type) &&
       executorch::runtime::isFloatingType(b_type)) {
@@ -45,6 +73,8 @@ ScalarType get_compute_type(ScalarType a_type, ScalarType b_type) {

 Tensor& div_out(
     RuntimeContext& ctx, const Tensor& a, const Tensor& b, Tensor& out) {
+  ET_DCHECK_MSG(!tensor_has_zero(b), "divisor tensor contains zero");
+
   ET_KERNEL_CHECK(
       ctx,
       torch::executor::resize_to_broadcast_target_size(a, b, out) == Error::Ok,
@@ -179,6 +209,8 @@ Tensor& div_out_mode(
     const Tensor& b,
     std::optional<std::string_view> mode,
     Tensor& out) {
+  ET_DCHECK_MSG(!tensor_has_zero(b), "divisor tensor contains zero");
+
   ET_KERNEL_CHECK(
       ctx,
       torch::executor::resize_to_broadcast_target_size(a, b, out) == Error::Ok,
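
ET_DCHECK_MSG is a debug-build assertion, which is why the helper is marked
[[maybe_unused]]: in release builds the check expands to nothing, so the
O(numel) divisor scan is only paid in debug runs and the function would
otherwise trigger unused-function warnings. A rough Python analogue of that
trade-off, using a plain assert (which is similarly stripped under -O; the
div_checked name is illustrative):

    import torch

    def div_checked(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
        # Debug-only guard: removed when Python runs with -O, much like
        # ET_DCHECK_MSG is compiled out of release builds.
        assert not bool((b == 0).any()), "divisor tensor contains zero"
        return a / b

    print(div_checked(torch.tensor([4.0, 9.0]), torch.tensor([2.0, 3.0])))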