
Commit 91bc89b

JacobSzwejbka authored and facebook-github-bot committed
Verify input dim order (#18643)
Verify input dim order (#18643)

Summary: Add error checks around dim order and memory format.

Differential Revision: D99129347
1 parent 5e8a0df commit 91bc89b

5 files changed

Lines changed: 135 additions & 2 deletions
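The user-visible effect of the new checks: an input whose memory layout (and hence dim order) does not match what the program was exported with is now rejected with an explicit error instead of being accepted unchecked. Below is a minimal sketch of that behavior, reusing the ModuleChannelsLast and create_program helpers exercised by the tests in this commit; the import paths are assumptions for illustration, not part of the commit:

import torch

# Assumed module paths; the test file below obtains these via its own fixtures.
from executorch.extension.pybindings.portable_lib import (
    _load_for_executorch_from_buffer,
)
from executorch.extension.pybindings.test.make_test import (
    ModuleChannelsLast,
    create_program,
)

exported_program, inputs = create_program(ModuleChannelsLast())
executorch_module = _load_for_executorch_from_buffer(exported_program.buffer)

executorch_module(inputs[0])  # channels-last, as exported: runs fine
try:
    # Contiguous tensor -> dim order mismatch -> RuntimeError after this commit.
    executorch_module(inputs[0].contiguous())
except RuntimeError as err:
    print(f"rejected: {err}")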


extension/pybindings/test/test_pybindings.py

Lines changed: 79 additions & 0 deletions
@@ -54,6 +54,20 @@ def setUp(self):
         self.load_prog_fn = runtime._load_program_from_buffer
         self.runtime = runtime

+    def _make_non_dense_strided_input(self) -> torch.Tensor:
+        strided_input = torch.arange(1 * 2 * 3 * 8, dtype=torch.float32).reshape(
+            1, 2, 3, 8
+        )[..., ::2]
+        self.assertEqual(tuple(strided_input.shape), (1, 2, 3, 4))
+        self.assertFalse(strided_input.is_contiguous())
+        self.assertFalse(strided_input.is_contiguous(memory_format=torch.channels_last))
+        return strided_input
+
+    def _create_program_with_inputs(
+        self, eager_module: torch.nn.Module, inputs: tuple[torch.Tensor, ...]
+    ):
+        return to_edge(export(eager_module, inputs, strict=True)).to_executorch()
+
     def test_e2e(self):
         exported_program, inputs = create_program(ModuleAdd())
         executorch_module = self.load_fn(exported_program.buffer)
@@ -204,6 +218,37 @@ def test_channels_last_in_default_out(self) -> None:
         expected = model(inputs[0])
         self.assertTrue(torch.allclose(expected, executorch_output))

+    def test_non_contiguous_input_dim_order_mismatch(self) -> None:
+        """Contiguous tensor passed to a channels-last model raises on dim order mismatch."""
+        model = ModuleChannelsLast()
+        exported_program, inputs = create_program(model)
+        executorch_module = self.load_fn(exported_program.buffer)
+
+        # Model expects channels-last; passing contiguous triggers mismatch.
+        contiguous_input = inputs[0].contiguous()
+        self.assertRaises(RuntimeError, executorch_module, contiguous_input)
+
+    def test_channels_last_input_dim_order_mismatch(self) -> None:
+        """Channels-last tensor passed to a contiguous model raises on dim order mismatch."""
+        model = ModuleChannelsLast()
+        inputs = (torch.ones(1, 2, 3, 4),)
+        exported_program = self._create_program_with_inputs(model, inputs)
+        executorch_module = self.load_fn(exported_program.buffer)
+
+        channels_last_input = inputs[0].to(memory_format=torch.channels_last)
+        self.assertRaises(RuntimeError, executorch_module, channels_last_input)
+
+    def test_strided_input_dim_order_error(self) -> None:
+        model = ModuleChannelsLast()
+        exported_program = self._create_program_with_inputs(
+            model, (torch.ones(1, 2, 3, 4),)
+        )
+        executorch_module = self.load_fn(exported_program.buffer)
+
+        self.assertRaises(
+            RuntimeError, executorch_module, self._make_non_dense_strided_input()
+        )
+
     def test_method_meta(self) -> None:
         exported_program, inputs = create_program(ModuleAdd())

@@ -466,6 +511,40 @@ def test_method_channels_last_in_default_out(self) -> None:
         expected = model(inputs[0])
         self.assertTrue(torch.allclose(expected, executorch_output))

+    def test_method_non_contiguous_input_dim_order_mismatch(self) -> None:
+        """Contiguous tensor passed to a channels-last method raises on dim order mismatch."""
+        model = ModuleChannelsLast()
+        exported_program, inputs = create_program(model)
+        executorch_program = self.load_prog_fn(exported_program.buffer)
+        executorch_method = executorch_program.load_method("forward")
+
+        # Model expects channels-last; passing contiguous triggers mismatch.
+        contiguous_input = inputs[0].contiguous()
+        self.assertRaises(RuntimeError, executorch_method, contiguous_input)
+
+    def test_method_channels_last_input_dim_order_mismatch(self) -> None:
+        """Channels-last tensor passed to a contiguous method raises on dim order mismatch."""
+        model = ModuleChannelsLast()
+        inputs = (torch.ones(1, 2, 3, 4),)
+        exported_program = self._create_program_with_inputs(model, inputs)
+        executorch_program = self.load_prog_fn(exported_program.buffer)
+        executorch_method = executorch_program.load_method("forward")
+
+        channels_last_input = inputs[0].to(memory_format=torch.channels_last)
+        self.assertRaises(RuntimeError, executorch_method, channels_last_input)
+
+    def test_method_strided_input_dim_order_error(self) -> None:
+        model = ModuleChannelsLast()
+        exported_program = self._create_program_with_inputs(
+            model, (torch.ones(1, 2, 3, 4),)
+        )
+        executorch_program = self.load_prog_fn(exported_program.buffer)
+        executorch_method = executorch_program.load_method("forward")
+
+        self.assertRaises(
+            RuntimeError, executorch_method, self._make_non_dense_strided_input()
+        )
+
     def test_method_bad_name(self) -> None:
         exported_program, inputs = create_program(ModuleAdd())
         executorch_program = self.load_prog_fn(exported_program.buffer)

runtime/core/exec_aten/util/tensor_util.h

Lines changed: 7 additions & 0 deletions
@@ -1140,6 +1140,13 @@ bool extract_scalar_tensor(executorch::aten::Tensor tensor, BOOL_T* out_val) {

 /// These APIs should not be used outside of Executor.cpp.
 namespace internal {
+/**
+ * Validate that t_src matches the model's expected dim order.
+ */
+ET_NODISCARD Error check_tensor_data_layout(
+    const executorch::aten::Tensor& t_src,
+    executorch::runtime::Span<const uint8_t> expected_dim_order);
+
 /**
  * Share t_src's data_ptr with t_dst.
  */

runtime/core/exec_aten/util/tensor_util_aten.cpp

Lines changed: 17 additions & 2 deletions
@@ -120,7 +120,20 @@ bool tensors_have_same_dim_order(

 namespace internal {

-Error share_tensor_data(const at::Tensor& t_dst, const at::Tensor& t_src) {
+Error check_tensor_data_layout(
+    const at::Tensor& t_src,
+    executorch::runtime::Span<const uint8_t> expected_dim_order) {
+  // There's some annoyance on teams doing weird stuff today that's hard to
+  // migrate. ATen mode isn't supported in the CMake build and it's really only
+  // used for testing, so we can just skip this check for now.
+  (void)t_src;
+  (void)expected_dim_order;
+  return Error::Ok;
+}
+
+Error share_tensor_data(
+    const at::Tensor& t_dst,
+    const at::Tensor& t_src) {
   at::StorageImpl* storage =
       t_dst.unsafeGetTensorImpl()->unsafe_storage().unsafeGetStorageImpl();

@@ -143,7 +156,9 @@ Error share_tensor_data(const at::Tensor& t_dst, const at::Tensor& t_src) {
   return Error::Ok;
 }

-Error copy_tensor_data(const at::Tensor& t_dst, const at::Tensor& t_src) {
+Error copy_tensor_data(
+    const at::Tensor& t_dst,
+    const at::Tensor& t_src) {
   void* dst_data_ptr = t_dst.unsafeGetTensorImpl()
                            ->unsafe_storage()
                            .unsafeGetStorageImpl()

runtime/core/exec_aten/util/tensor_util_portable.cpp

Lines changed: 26 additions & 0 deletions
@@ -137,6 +137,32 @@ bool tensors_have_same_dim_order(

 namespace internal {

+Error check_tensor_data_layout(
+    const torch::executor::Tensor& t_src,
+    executorch::runtime::Span<const uint8_t> expected_dim_order) {
+  // Input tensors don't actually have to have a dim order, so we can't check.
+  // Tech debt from dim order being a later addition to the tensor spec.
+  if (t_src.dim_order().data() == nullptr) {
+    return Error::Ok;
+  }
+
+  auto src_dim_order = t_src.dim_order();
+  ET_CHECK_OR_RETURN_ERROR(
+      expected_dim_order.size() == src_dim_order.size(),
+      InvalidArgument,
+      "Input dim order size (%zu) != expected (%zu).",
+      src_dim_order.size(),
+      expected_dim_order.size());
+  ET_CHECK_OR_RETURN_ERROR(
+      std::equal(
+          expected_dim_order.begin(),
+          expected_dim_order.end(),
+          src_dim_order.data()),
+      InvalidArgument,
+      "Input dim order does not match the model.");
+  return Error::Ok;
+}
+
 Error share_tensor_data(
     const torch::executor::Tensor& t_dst,
     const torch::executor::Tensor& t_src) {
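For context on what this portable-mode check compares: dim order is the permutation of tensor dimensions from outermost to innermost in memory, so a contiguous NCHW tensor has dim order (0, 1, 2, 3) while a channels-last one has (0, 2, 3, 1). A quick illustration that derives it from strides; dim_order_from_strides is a hypothetical helper for this sketch only (it ignores stride ties, e.g. from size-1 dims):

import torch

def dim_order_from_strides(t: torch.Tensor) -> tuple[int, ...]:
    # Dims sorted from largest to smallest stride = outermost-to-innermost layout.
    return tuple(sorted(range(t.dim()), key=lambda d: t.stride(d), reverse=True))

contiguous = torch.ones(1, 2, 3, 4)
channels_last = contiguous.to(memory_format=torch.channels_last)

print(dim_order_from_strides(contiguous))     # (0, 1, 2, 3)
print(dim_order_from_strides(channels_last))  # (0, 2, 3, 1)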

runtime/executor/method.cpp

Lines changed: 6 additions & 0 deletions
@@ -1184,6 +1184,12 @@ Method::set_input(const EValue& input_evalue, size_t input_idx) {
       "Error resizing tensor at input %" ET_PRIsize_t,
       input_idx);
   auto tensor_meta = this->method_meta().input_tensor_meta(input_idx);
+  ET_CHECK_OK_OR_RETURN_ERROR(tensor_meta.error());
+  auto expected_dim_order = tensor_meta->dim_order();
+  ET_CHECK_OK_OR_RETURN_ERROR(
+      internal::check_tensor_data_layout(t_src, expected_dim_order),
+      "Error validating tensor layout at input %" ET_PRIsize_t,
+      input_idx);
   if (tensor_meta->is_memory_planned()) {
     ET_CHECK_OK_OR_RETURN_ERROR(
         internal::copy_tensor_data(t_dst, t_src),
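Because the validation lives in Method::set_input, it covers both the module-style and method-style pybindings paths, and the fix on the caller's side is simply to pass a tensor in the layout the program was exported with. Below is a sketch of the method-level path mirroring the new test_method_* tests; as above, the import paths are assumptions for illustration:

import torch

from executorch.extension.pybindings.portable_lib import _load_program_from_buffer
from executorch.extension.pybindings.test.make_test import (
    ModuleChannelsLast,
    create_program,
)

exported_program, inputs = create_program(ModuleChannelsLast())
program = _load_program_from_buffer(exported_program.buffer)
method = program.load_method("forward")

wrong_layout = inputs[0].contiguous()
try:
    method(wrong_layout)  # dim order mismatch: InvalidArgument -> RuntimeError
except RuntimeError:
    pass

# Converting to the layout the program expects makes the call succeed.
method(wrong_layout.to(memory_format=torch.channels_last))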
