File tree Expand file tree Collapse file tree
backends/cadence/fusion_g3/operators Expand file tree Collapse file tree Original file line number Diff line number Diff line change @@ -100,22 +100,23 @@ Tensor& cat_out(
100100 out_shapes[i] = out_size[i];
101101 }
102102
103- bool optimized = true ;
104-
103+ bool all_same_dtype = true ;
105104 for (int i = 0 ; i < tensors.size (); i++) {
106105 if (out.scalar_type () != tensors[i].scalar_type ()) {
107- optimized = false ;
106+ all_same_dtype = false ;
108107 break ;
109108 }
110109 }
111110
112- if ((optimized) && (out.scalar_type () == ScalarType::Int) ||
113- (out.scalar_type () == ScalarType::Short) ||
114- (out.scalar_type () == ScalarType::Char) ||
115- (out.scalar_type () == ScalarType::UInt32) ||
116- (out.scalar_type () == ScalarType::UInt16) ||
117- (out.scalar_type () == ScalarType::Byte) ||
118- (out.scalar_type () == ScalarType::Float)) {
111+ bool supported_dtype = out.scalar_type () == ScalarType::Int ||
112+ out.scalar_type () == ScalarType::Short ||
113+ out.scalar_type () == ScalarType::Char ||
114+ out.scalar_type () == ScalarType::UInt32 ||
115+ out.scalar_type () == ScalarType::UInt16 ||
116+ out.scalar_type () == ScalarType::Byte ||
117+ out.scalar_type () == ScalarType::Float;
118+
119+ if (all_same_dtype && supported_dtype) {
119120 XT_KERNEL_CHECK (
120121 ctx,
121122 out,
You can’t perform that action at this time.
0 commit comments