Commit fe33ca5

Author: NefAI (committed)
fix(runtime): tensors_have_same_dim_order Tier A legacy + Tier B semantic (Alternative A)
- Tier A: restore all_contiguous || all_channels_last over the full list (pre-PR C contract) so mixed-rank and broadcast call sites pass.
- Tier B: when Tier A fails, same-rank tensors use semantic equivalence to the reference tensor; different-rank tensors must match the reference's contiguous vs channels-last family.
- Parity: apply the same logic in tensor_util_aten.cpp (two_tensors_semantic_same_layout).
- Docs in tensor_util.h; rename the different-rank test to DifferentRankSameLegacyFormatFamilyPasses.

Fixes unittest-editable regressions from pairwise-only dim_order checks while preserving #16032 degenerate-shape behavior.
1 parent 0a12cfc · commit fe33ca5

4 files changed

Lines changed: 158 additions & 40 deletions
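
At a glance, the two tiers compose as follows. This is a condensed sketch, not the literal implementation (the ATen and portable backends below differ in how they read dim orders); `same_format_family` is a hypothetical stand-in for the inline contiguous/channels-last family comparison done in each backend:

    // Sketch only: Tier A first, then per-tensor Tier B against tensor_list[0].
    bool tensors_have_same_dim_order_sketch(
        executorch::aten::ArrayRef<executorch::aten::Tensor> ts) {
      if (ts.size() < 2) {
        return true;
      }
      // Tier A (legacy): all contiguous-order, or all channels-last-order.
      if (tensors_share_legacy_format_family(ts)) {
        return true;
      }
      // Tier B: compare each remaining tensor to the first one.
      for (size_t i = 1; i < ts.size(); ++i) {
        if (ts[i].dim() == ts[0].dim()) {
          // Same rank: semantic equivalence (labels, else non-size-1 strides).
          if (!two_tensors_same_dim_order(ts[0], ts[i])) {
            return false;
          }
        } else if (!same_format_family(ts[0], ts[i])) { // hypothetical helper
          // Different rank: must share the reference's format family.
          return false;
        }
      }
      return true;
    }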

runtime/core/exec_aten/util/tensor_util.h

Lines changed: 16 additions & 26 deletions
@@ -1280,37 +1280,34 @@ bool tensor_is_default_dim_order(executorch::aten::Tensor t);
 bool tensor_is_channels_last_dim_order(executorch::aten::Tensor t);
 
 /**
- * Asserts that four tensors have the same dim_order
+ * Returns true if all tensors are in a compatible layout for portable kernels.
  *
- * Note that this macro only tests dim order, but not others like actual data,
- * sizes, etc.
+ * First, the legacy rule: either every tensor is contiguous-order
+ * (`is_contiguous_dim_order`) or every tensor is channels-last-order
+ * (`is_channels_last_dim_order`). That matches mixed-rank argument lists
+ * (e.g. batch norm with reduced outputs), broadcast shapes, and typical
+ * elementwise ops.
  *
+ * If that fails, falls back to semantic equivalence for tensors with the same
+ * rank as the first tensor: matching dim_order labels, or matching strides on
+ * non-size-1 dimensions (degenerate-shape / ambiguous dim_order cases).
+ * Tensors with a different rank than the first must match the first tensor's
+ * format family (both contiguous-order, or both channels-last-order).
+ *
+ * Does not validate sizes, dtypes, or data.
  */
 bool tensors_have_same_dim_order(
     const executorch::aten::ArrayRef<executorch::aten::Tensor> tensor_list);
 
-/**
- * Asserts that two tensors have the same dim_order
- *
- * Note that this macro only tests dim order, but not others like actual data,
- * sizes, etc.
- */
+/** @see tensors_have_same_dim_order(ArrayRef) */
 inline bool tensors_have_same_dim_order(
     const executorch::aten::Tensor& a,
     const executorch::aten::Tensor& b) {
   executorch::aten::Tensor tensor_list[2] = {a, b};
   return tensors_have_same_dim_order(tensor_list);
 }
 
-/**
- * Asserts that three tensors have the same dim_order
- *
- * Note that this macro only tests dim order, but not others like actual data,
- * sizes, etc.
- *
- */
+/** @see tensors_have_same_dim_order(ArrayRef) */
 inline bool tensors_have_same_dim_order(
     const executorch::aten::Tensor& a,
     const executorch::aten::Tensor& b,
@@ -1319,14 +1316,7 @@ inline bool tensors_have_same_dim_order(
   return tensors_have_same_dim_order(tensor_list);
 }
 
-/**
- * Asserts that four tensors have the same dim_order
- *
- * Note that this macro only tests dim order, but not others like actual data,
- * sizes, etc.
- *
- */
+/** @see tensors_have_same_dim_order(ArrayRef) */
 inline bool tensors_have_same_dim_order(
     const executorch::aten::Tensor& a,
     const executorch::aten::Tensor& b,
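
The two-, three-, and four-tensor overloads all forward to the ArrayRef version, so the documented contract applies uniformly. A typical portable-kernel call site looks roughly like this (a sketch, assuming the usual ET_KERNEL_CHECK pattern used by ExecuTorch portable ops, with `ctx`, `in`, and `out` being the kernel's own arguments):

    // Reject mismatched layouts up front at an op's entry point.
    ET_KERNEL_CHECK(
        ctx, tensors_have_same_dim_order(in, out), InvalidArgument, out);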

runtime/core/exec_aten/util/tensor_util_aten.cpp

Lines changed: 83 additions & 5 deletions
@@ -78,6 +78,46 @@ inline bool tensor_is_default_or_channels_last_dim_order(at::Tensor t) {
   return ret_val;
 }
 
+namespace {
+
+// Same-rank semantic layout match (dim_order labels, else strides with
+// size-1 dims skipped). Used when the legacy format-family check fails.
+bool two_tensors_semantic_same_layout(
+    const executorch::aten::Tensor& a,
+    const executorch::aten::Tensor& b) {
+  if (a.dim() != b.dim()) {
+    return false;
+  }
+  const int ndim = static_cast<int>(a.dim());
+  executorch::aten::DimOrderType order_a[kTensorDimensionLimit];
+  executorch::aten::DimOrderType order_b[kTensorDimensionLimit];
+  if (get_dim_order(a, order_a, a.dim()) != Error::Ok ||
+      get_dim_order(b, order_b, b.dim()) != Error::Ok) {
+    return false;
+  }
+  bool labels_match = true;
+  for (int i = 0; i < ndim; ++i) {
+    if (order_a[i] != order_b[i]) {
+      labels_match = false;
+      break;
+    }
+  }
+  if (labels_match) {
+    return true;
+  }
+  for (int i = 0; i < ndim; ++i) {
+    if (a.size(i) == 1 && b.size(i) == 1) {
+      continue;
+    }
+    if (a.stride(i) != b.stride(i)) {
+      return false;
+    }
+  }
+  return true;
+}
+
+} // namespace
+
 bool tensors_have_same_dim_order(
     const executorch::aten::ArrayRef<executorch::aten::Tensor> tensor_list) {
   if (tensor_list.size() < 2) {
@@ -110,12 +150,50 @@ bool tensors_have_same_dim_order(
         is_channels_last_dim_order(other_dim_order, tensor_list[i].dim());
   }
 
-  ET_CHECK_OR_RETURN_FALSE(
-      all_contiguous || all_channels_last,
-      "%zd input tensors have different dim orders",
-      tensor_list.size());
+  if (all_contiguous || all_channels_last) {
+    return true;
+  }
+
+  const executorch::aten::Tensor& ref = tensor_list[0];
+  const bool ref_contiguous =
+      is_contiguous_dim_order(first_dim_order, ref.dim());
+  const bool ref_channels_last =
+      is_channels_last_dim_order(first_dim_order, ref.dim());
 
-  return all_contiguous || all_channels_last;
+  for (size_t i = 1; i < tensor_list.size(); ++i) {
+    const executorch::aten::Tensor& t = tensor_list[i];
+    if (t.dim() == ref.dim()) {
+      if (!two_tensors_semantic_same_layout(ref, t)) {
+        ET_LOG(
+            Error,
+            "%zd input tensors have different dim orders",
+            tensor_list.size());
+        return false;
+      }
+    } else {
+      if (get_dim_order(t, other_dim_order, t.dim()) != Error::Ok) {
+        ET_LOG(
+            Error,
+            "%zd input tensors have different dim orders",
+            tensor_list.size());
+        return false;
+      }
+      const bool t_contiguous =
+          is_contiguous_dim_order(other_dim_order, t.dim());
+      const bool t_channels_last =
+          is_channels_last_dim_order(other_dim_order, t.dim());
+      const bool ok = (ref_contiguous && t_contiguous) ||
+          (ref_channels_last && t_channels_last);
+      if (!ok) {
+        ET_LOG(
+            Error,
+            "%zd input tensors have different dim orders",
+            tensor_list.size());
+        return false;
+      }
+    }
+  }
+  return true;
 }
 
 namespace internal {
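
The stride fallback in two_tensors_semantic_same_layout is what preserves the #16032 degenerate-shape behavior. A minimal ATen illustration of the case it accepts (assuming at::ones and contiguous(MemoryFormat::ChannelsLast) produce the usual strides):

    #include <ATen/ATen.h>

    // With C == 1, NCHW and NHWC describe the same memory, but the dim_order
    // labels derived from strides can differ. The stride loop skips dim 1
    // (size 1 on both sides); dims 0, 2, 3 agree, so the pair matches.
    void degenerate_layout_example() {
      at::Tensor nchw = at::ones({2, 1, 4, 4});
      // strides: {16, 16, 4, 1}
      at::Tensor nhwc = nchw.contiguous(at::MemoryFormat::ChannelsLast);
      // strides: {16, 1, 4, 1}
    }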

runtime/core/exec_aten/util/tensor_util_portable.cpp

Lines changed: 55 additions & 6 deletions
@@ -154,20 +154,69 @@ bool two_tensors_same_dim_order(
   return true;
 }
 
+// Tier A: every tensor is contiguous-order or every tensor is channels-last
+// (original portable contract). Handles mixed rank, broadcast shapes, and
+// reduced aux outputs (e.g. batch norm mean tensors).
+bool tensors_share_legacy_format_family(
+    const executorch::aten::ArrayRef<executorch::aten::Tensor> tensor_list) {
+  bool all_contiguous = true;
+  bool all_channels_last = true;
+  for (const auto i : c10::irange(tensor_list.size())) {
+    all_contiguous = all_contiguous &&
+        is_contiguous_dim_order(
+            tensor_list[i].dim_order().data(),
+            tensor_list[i].dim_order().size());
+    all_channels_last = all_channels_last &&
+        is_channels_last_dim_order(
+            tensor_list[i].dim_order().data(),
+            tensor_list[i].dim_order().size());
+  }
+  return all_contiguous || all_channels_last;
+}
+
 } // namespace
 
 bool tensors_have_same_dim_order(
     const executorch::aten::ArrayRef<executorch::aten::Tensor> tensor_list) {
   if (tensor_list.size() < 2) {
     return true;
   }
+
+  if (tensors_share_legacy_format_family(tensor_list)) {
+    return true;
+  }
+
+  const executorch::aten::Tensor& ref = tensor_list[0];
+  const bool ref_contiguous =
+      is_contiguous_dim_order(ref.dim_order().data(), ref.dim_order().size());
+  const bool ref_channels_last =
+      is_channels_last_dim_order(ref.dim_order().data(), ref.dim_order().size());
+
   for (size_t i = 1; i < tensor_list.size(); ++i) {
-    if (!two_tensors_same_dim_order(tensor_list[0], tensor_list[i])) {
-      ET_LOG(
-          Error,
-          "%zd input tensors have different dim orders",
-          tensor_list.size());
-      return false;
+    const executorch::aten::Tensor& t = tensor_list[i];
+    if (t.dim() == ref.dim()) {
+      if (!two_tensors_same_dim_order(ref, t)) {
+        ET_LOG(
+            Error,
+            "%zd input tensors have different dim orders",
+            tensor_list.size());
+        return false;
+      }
+    } else {
+      const bool t_contiguous =
+          is_contiguous_dim_order(t.dim_order().data(), t.dim_order().size());
+      const bool t_channels_last =
+          is_channels_last_dim_order(t.dim_order().data(), t.dim_order().size());
+      const bool ok =
+          (ref_contiguous && t_contiguous) ||
+          (ref_channels_last && t_channels_last);
+      if (!ok) {
+        ET_LOG(
+            Error,
+            "%zd input tensors have different dim orders",
+            tensor_list.size());
+        return false;
+      }
    }
  }
  return true;
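
The restored Tier A is what batch-norm-style call sites rely on. A sketch using the TensorFactory fixture from the unit tests below (`tf_float_` as in tensor_util_test.cpp):

    // All three tensors are contiguous-order, so Tier A accepts the list
    // before any pairwise rank comparison happens.
    Tensor input = tf_float_.ones({2, 3, 4, 4}); // 4-D activation
    Tensor mean = tf_float_.ones({3});           // 1-D reduced stat
    Tensor var = tf_float_.ones({3});            // 1-D reduced stat
    EXPECT_TRUE(tensors_have_same_dim_order(input, mean, var));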

runtime/core/exec_aten/util/test/tensor_util_test.cpp

Lines changed: 4 additions & 3 deletions
@@ -692,13 +692,14 @@ TEST_F(TensorUtilTest, SemanticEquivalencePartialDegenerateFails) {
   EXPECT_FALSE(tensors_have_same_dim_order(nchw, nhwc));
 }
 
-TEST_F(TensorUtilTest, SemanticEquivalenceDifferentRankFails) {
+TEST_F(TensorUtilTest, DifferentRankSameLegacyFormatFamilyPasses) {
   using namespace torch::executor;
-  // Different ranks should fail
+  // Legacy rule: all contiguous-order (or all channels-last) passes even when
+  // ranks differ (e.g. reduced outputs vs full activations).
   Tensor a = tf_float_.ones({2, 3, 4, 4});
   Tensor b = tf_float_.ones({2, 3, 4});
 
-  EXPECT_FALSE(tensors_have_same_dim_order(a, b));
+  EXPECT_TRUE(tensors_have_same_dim_order(a, b));
 }
 
 TEST_F(TensorUtilTest, SemanticEquivalenceSameLabelsSameResult) {
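
For contrast, a case the new contract still rejects (a sketch; `make_channels_last` is a hypothetical helper for building a tensor with NHWC dim_order {0, 2, 3, 1}, not part of the fixture): once Tier A fails, a tensor whose rank differs from the reference must land in the same format family, and a 3-D tensor can never be channels-last-order.

    // Hypothetical helper: make_channels_last({2, 3, 4, 4}) yields dim_order
    // {0, 2, 3, 1}. Tier A fails (b is contiguous-order, a is not), and Tier
    // B's different-rank branch rejects b because it cannot match a's
    // channels-last family.
    Tensor a = make_channels_last({2, 3, 4, 4});
    Tensor b = tf_float_.ones({2, 3, 4});
    EXPECT_FALSE(tensors_have_same_dim_order(a, b));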
