Skip to content

Commit 573f930

Browse files
authored
Fix operator precedence bug in cat_out dtype check
Differential Revision: D98801435 Pull Request resolved: #18596
1 parent 433569a commit 573f930

1 file changed

Lines changed: 11 additions & 10 deletions

File tree

backends/cadence/fusion_g3/operators/op_cat.cpp

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -100,22 +100,23 @@ Tensor& cat_out(
100100
out_shapes[i] = out_size[i];
101101
}
102102

103-
bool optimized = true;
104-
103+
bool all_same_dtype = true;
105104
for (int i = 0; i < tensors.size(); i++) {
106105
if (out.scalar_type() != tensors[i].scalar_type()) {
107-
optimized = false;
106+
all_same_dtype = false;
108107
break;
109108
}
110109
}
111110

112-
if ((optimized) && (out.scalar_type() == ScalarType::Int) ||
113-
(out.scalar_type() == ScalarType::Short) ||
114-
(out.scalar_type() == ScalarType::Char) ||
115-
(out.scalar_type() == ScalarType::UInt32) ||
116-
(out.scalar_type() == ScalarType::UInt16) ||
117-
(out.scalar_type() == ScalarType::Byte) ||
118-
(out.scalar_type() == ScalarType::Float)) {
111+
bool supported_dtype = out.scalar_type() == ScalarType::Int ||
112+
out.scalar_type() == ScalarType::Short ||
113+
out.scalar_type() == ScalarType::Char ||
114+
out.scalar_type() == ScalarType::UInt32 ||
115+
out.scalar_type() == ScalarType::UInt16 ||
116+
out.scalar_type() == ScalarType::Byte ||
117+
out.scalar_type() == ScalarType::Float;
118+
119+
if (all_same_dtype && supported_dtype) {
119120
XT_KERNEL_CHECK(
120121
ctx,
121122
out,

0 commit comments

Comments
 (0)