Skip to content

Commit c229dbc

Browse files
Add aten::_conj_physical.out operator (#17207) (#17383)
Summary: Added `aten::_conj_physical.out` which implements https://docs.pytorch.org/docs/stable/generated/torch.conj_physical.html Differential Revision: D92991877
1 parent b871398 commit c229dbc

6 files changed

Lines changed: 222 additions & 5 deletions

File tree

kernels/aten/functions.yaml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44

55
- op: _cdist_forward.out
66

7+
- op: _conj_physical.out
8+
79
- op: _fake_quantize_per_tensor_affine_cachemask_tensor_qparams.out
810

911
- op: _fft_c2r.out
Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#include <executorch/kernels/portable/cpu/util/functional_util.h>
10+
#include <executorch/runtime/kernel/kernel_includes.h>
11+
12+
namespace torch {
13+
namespace executor {
14+
namespace native {
15+
16+
using executorch::aten::Tensor;
17+
18+
Tensor&
19+
_conj_physical_out(KernelRuntimeContext& ctx, const Tensor& in, Tensor& out) {
20+
ET_KERNEL_CHECK_MSG(
21+
ctx,
22+
resize_tensor(out, in.sizes()) == Error::Ok,
23+
InvalidArgument,
24+
out,
25+
"Failed to resize output tensor.");
26+
27+
ET_KERNEL_CHECK(ctx, tensors_have_same_dtype(in, out), InvalidArgument, out);
28+
29+
ET_KERNEL_CHECK(
30+
ctx, tensors_have_same_dim_order(in, out), InvalidArgument, out);
31+
32+
// @lint-ignore CLANGTIDY facebook-hte-CArray
33+
static constexpr const char op_name[] = "_conj_physical.out";
34+
35+
ET_SWITCH_COMPLEXH_TYPES(in.scalar_type(), ctx, op_name, CTYPE, [&] {
36+
apply_unary_map_fn<CTYPE, CTYPE>(
37+
[](const CTYPE val_in) -> CTYPE {
38+
return CTYPE(val_in.real_, -val_in.imag_);
39+
},
40+
in.const_data_ptr<CTYPE>(),
41+
out.mutable_data_ptr<CTYPE>(),
42+
in.numel());
43+
});
44+
45+
return out;
46+
}
47+
48+
} // namespace native
49+
} // namespace executor
50+
} // namespace torch

kernels/portable/functions.yaml

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,11 @@
2222
- arg_meta: null
2323
kernel_name: torch::executor::_cdist_forward_out
2424

25+
- op: _conj_physical.out
26+
kernels:
27+
- arg_meta: null
28+
kernel_name: torch::executor::_conj_physical_out
29+
2530
- op: _log_softmax.out
2631
kernels:
2732
- arg_meta: null
@@ -1043,4 +1048,4 @@
10431048
- func: dim_order_ops::_clone_dim_order.out(Tensor self, *, bool non_blocking=False, int[]? dim_order=None, Tensor(a!) out) -> Tensor(a!)
10441049
kernels:
10451050
- arg_meta: null
1046-
kernel_name: torch::executor::_clone_dim_order_out
1051+
kernel_name: torch::executor::_clone_dim_order_out
Lines changed: 153 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,153 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#include <executorch/kernels/test/FunctionHeaderWrapper.h> // Declares the operator
10+
#include <executorch/kernels/test/TestUtil.h>
11+
#include <executorch/runtime/core/exec_aten/exec_aten.h>
12+
#include <executorch/runtime/core/exec_aten/testing_util/tensor_factory.h>
13+
#include <executorch/runtime/core/exec_aten/testing_util/tensor_util.h>
14+
15+
#include <gtest/gtest.h>
16+
17+
using namespace ::testing;
18+
using executorch::aten::ScalarType;
19+
using executorch::aten::Tensor;
20+
using torch::executor::testing::TensorFactory;
21+
22+
class OpConjPhysicalOutTest : public OperatorTest {
23+
protected:
24+
Tensor& op_conj_physical_out(const Tensor& in, Tensor& out) {
25+
return torch::executor::aten::_conj_physical_outf(context_, in, out);
26+
}
27+
};
28+
29+
TEST_F(OpConjPhysicalOutTest, ComplexFloatBasic) {
30+
TensorFactory<ScalarType::ComplexFloat> tf;
31+
32+
const std::vector<int32_t> sizes = {2, 2};
33+
34+
// Create input: (1+2i), (3+4i), (5-6i), (-7+8i)
35+
Tensor in = tf.make(
36+
sizes,
37+
{executorch::aten::complex<float>(1.0f, 2.0f),
38+
executorch::aten::complex<float>(3.0f, 4.0f),
39+
executorch::aten::complex<float>(5.0f, -6.0f),
40+
executorch::aten::complex<float>(-7.0f, 8.0f)});
41+
42+
Tensor out = tf.zeros(sizes);
43+
44+
op_conj_physical_out(in, out);
45+
46+
// Expected: (1-2i), (3-4i), (5+6i), (-7-8i)
47+
Tensor expected = tf.make(
48+
sizes,
49+
{executorch::aten::complex<float>(1.0f, -2.0f),
50+
executorch::aten::complex<float>(3.0f, -4.0f),
51+
executorch::aten::complex<float>(5.0f, 6.0f),
52+
executorch::aten::complex<float>(-7.0f, -8.0f)});
53+
54+
EXPECT_TENSOR_EQ(out, expected);
55+
}
56+
57+
TEST_F(OpConjPhysicalOutTest, ComplexDoubleBasic) {
58+
TensorFactory<ScalarType::ComplexDouble> tf;
59+
60+
const std::vector<int32_t> sizes = {3};
61+
62+
Tensor in = tf.make(
63+
sizes,
64+
{executorch::aten::complex<double>(1.5, 2.5),
65+
executorch::aten::complex<double>(-3.5, 4.5),
66+
executorch::aten::complex<double>(0.0, -1.0)});
67+
68+
Tensor out = tf.zeros(sizes);
69+
70+
op_conj_physical_out(in, out);
71+
72+
Tensor expected = tf.make(
73+
sizes,
74+
{executorch::aten::complex<double>(1.5, -2.5),
75+
executorch::aten::complex<double>(-3.5, -4.5),
76+
executorch::aten::complex<double>(0.0, 1.0)});
77+
78+
EXPECT_TENSOR_EQ(out, expected);
79+
}
80+
81+
TEST_F(OpConjPhysicalOutTest, RealPartOnly) {
82+
TensorFactory<ScalarType::ComplexFloat> tf;
83+
84+
const std::vector<int32_t> sizes = {2};
85+
86+
// When imaginary part is zero, conjugate negates the imaginary part (0 -> -0)
87+
// Both are mathematically equivalent, so we verify values directly
88+
Tensor in = tf.make(
89+
sizes,
90+
{executorch::aten::complex<float>(5.0f, 0.0f),
91+
executorch::aten::complex<float>(-3.0f, 0.0f)});
92+
93+
Tensor out = tf.zeros(sizes);
94+
95+
op_conj_physical_out(in, out);
96+
97+
// Verify real parts are unchanged and imaginary parts are negated zeros
98+
const auto* out_data = out.const_data_ptr<executorch::aten::complex<float>>();
99+
EXPECT_EQ(out_data[0].real_, 5.0f);
100+
EXPECT_EQ(out_data[0].imag_, -0.0f);
101+
EXPECT_EQ(out_data[1].real_, -3.0f);
102+
EXPECT_EQ(out_data[1].imag_, -0.0f);
103+
}
104+
105+
TEST_F(OpConjPhysicalOutTest, ImaginaryPartOnly) {
106+
TensorFactory<ScalarType::ComplexFloat> tf;
107+
108+
const std::vector<int32_t> sizes = {2};
109+
110+
Tensor in = tf.make(
111+
sizes,
112+
{executorch::aten::complex<float>(0.0f, 5.0f),
113+
executorch::aten::complex<float>(0.0f, -3.0f)});
114+
115+
Tensor out = tf.zeros(sizes);
116+
117+
op_conj_physical_out(in, out);
118+
119+
Tensor expected = tf.make(
120+
sizes,
121+
{executorch::aten::complex<float>(0.0f, -5.0f),
122+
executorch::aten::complex<float>(0.0f, 3.0f)});
123+
124+
EXPECT_TENSOR_EQ(out, expected);
125+
}
126+
127+
TEST_F(OpConjPhysicalOutTest, EmptyTensor) {
128+
TensorFactory<ScalarType::ComplexFloat> tf;
129+
130+
const std::vector<int32_t> sizes = {0};
131+
132+
Tensor in = tf.make(sizes, {});
133+
Tensor out = tf.zeros(sizes);
134+
135+
op_conj_physical_out(in, out);
136+
137+
EXPECT_EQ(out.numel(), 0);
138+
}
139+
140+
TEST_F(OpConjPhysicalOutTest, MismatchedDtypeDies) {
141+
TensorFactory<ScalarType::ComplexFloat> tf_in;
142+
TensorFactory<ScalarType::ComplexDouble> tf_out;
143+
144+
const std::vector<int32_t> sizes = {2};
145+
146+
Tensor in = tf_in.make(
147+
sizes,
148+
{executorch::aten::complex<float>(1.0f, 2.0f),
149+
executorch::aten::complex<float>(3.0f, 4.0f)});
150+
Tensor out = tf_out.zeros(sizes);
151+
152+
ET_EXPECT_KERNEL_FAILURE(context_, op_conj_physical_out(in, out));
153+
}

kernels/test/targets.bzl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -166,6 +166,7 @@ def define_common_targets():
166166
_common_op_test("op__to_dim_order_copy_test", ["aten", "portable"])
167167
_common_op_test("op__empty_dim_order_test", ["aten", "portable"])
168168
_common_op_test("op__clone_dim_order_test", ["aten", "portable"])
169+
_common_op_test("op__conj_physical_test", ["aten", "portable"])
169170
_common_op_test("op_abs_test", ["aten", "portable"])
170171
_common_op_test("op_acos_test", ["aten", "portable"])
171172
_common_op_test("op_acosh_test", ["aten", "portable"])

shim_et/xplat/executorch/kernels/portable/op_registration_util.bzl

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1356,20 +1356,26 @@ ATEN_OPS = (
13561356
name = "op_zeros",
13571357
),
13581358
op_target(
1359-
name = "op__empty_dim_order",
1359+
name = "op__clone_dim_order",
13601360
deps = [
13611361
":scalar_utils",
1362+
"//executorch/kernels/portable/cpu/util:copy_ops_util",
13621363
],
13631364
),
13641365
op_target(
1365-
name = "op__to_dim_order_copy",
1366+
name = "op__conj_physical",
1367+
deps = [
1368+
"//executorch/kernels/portable/cpu/util:functional_util",
1369+
],
1370+
),
1371+
op_target(
1372+
name = "op__empty_dim_order",
13661373
deps = [
13671374
":scalar_utils",
1368-
"//executorch/kernels/portable/cpu/util:copy_ops_util",
13691375
],
13701376
),
13711377
op_target(
1372-
name = "op__clone_dim_order",
1378+
name = "op__to_dim_order_copy",
13731379
deps = [
13741380
":scalar_utils",
13751381
"//executorch/kernels/portable/cpu/util:copy_ops_util",

0 commit comments

Comments
 (0)