@@ -50,24 +50,47 @@ Tensor& add_out(
5050 // @lint-ignore CLANGTIDY facebook-hte-CArray
5151 static constexpr const char op_name[] = " add.out" ;
5252
53- ET_SWITCH_REALB_TYPES (compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
54- CTYPE_COMPUTE val_alpha;
53+ if (executorch::runtime::isComplexType (a.scalar_type ()) ||
54+ executorch::runtime::isComplexType (b.scalar_type ()) ||
55+ executorch::runtime::isComplexType (out.scalar_type ())) {
56+ // TODO: The current support for complex dtype enforces that input and
57+ // output tensors have the same dtype. Support mixed dtypes in the future.
5558 ET_KERNEL_CHECK (
56- ctx, utils::extract_scalar (alpha, &val_alpha), InvalidArgument, );
57- utils::apply_bitensor_elementwise_fn<
58- CTYPE_COMPUTE,
59- op_name,
60- utils::SupportedTensorDtypes::REALHBBF16>(
61- [val_alpha](const auto val_a, const auto val_b) {
62- return val_a + val_alpha * val_b;
63- },
6459 ctx,
65- a,
66- utils::SupportedTensorDtypes::REALHBBF16,
67- b,
68- utils::SupportedTensorDtypes::REALHBBF16,
60+ a.scalar_type () == b.scalar_type () &&
61+ a.scalar_type () == out.scalar_type (),
62+ InvalidArgument,
6963 out);
70- });
64+ ET_SWITCH_COMPLEXH_TYPES (out.scalar_type (), ctx, op_name, CTYPE, [&]() {
65+ CTYPE val_alpha = utils::scalar_to<CTYPE>(alpha);
66+ apply_binary_elementwise_fn<CTYPE, CTYPE, CTYPE>(
67+ [val_alpha](const CTYPE val_a, const CTYPE val_b) {
68+ return val_a + val_alpha * val_b;
69+ },
70+ a,
71+ b,
72+ out);
73+ });
74+ } else {
75+ ET_SWITCH_REALB_TYPES (compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
76+ CTYPE_COMPUTE val_alpha;
77+ ET_KERNEL_CHECK (
78+ ctx, utils::extract_scalar (alpha, &val_alpha), InvalidArgument, );
79+ utils::apply_bitensor_elementwise_fn<
80+ CTYPE_COMPUTE,
81+ op_name,
82+ utils::SupportedTensorDtypes::REALHBBF16>(
83+ [val_alpha](const auto val_a, const auto val_b) {
84+ return val_a + val_alpha * val_b;
85+ },
86+ ctx,
87+ a,
88+ utils::SupportedTensorDtypes::REALHBBF16,
89+ b,
90+ utils::SupportedTensorDtypes::REALHBBF16,
91+ out);
92+ });
93+ }
7194
7295 return out;
7396}
0 commit comments