diff --git a/extension/pybindings/test/test_pybindings.py b/extension/pybindings/test/test_pybindings.py index ec45428c7d7..a317d77c9da 100644 --- a/extension/pybindings/test/test_pybindings.py +++ b/extension/pybindings/test/test_pybindings.py @@ -54,6 +54,20 @@ def setUp(self): self.load_prog_fn = runtime._load_program_from_buffer self.runtime = runtime + def _make_non_dense_strided_input(self) -> torch.Tensor: + strided_input = torch.arange(1 * 2 * 3 * 8, dtype=torch.float32).reshape( + 1, 2, 3, 8 + )[..., ::2] + self.assertEqual(tuple(strided_input.shape), (1, 2, 3, 4)) + self.assertFalse(strided_input.is_contiguous()) + self.assertFalse(strided_input.is_contiguous(memory_format=torch.channels_last)) + return strided_input + + def _create_program_with_inputs( + self, eager_module: torch.nn.Module, inputs: tuple[torch.Tensor, ...] + ): + return to_edge(export(eager_module, inputs, strict=True)).to_executorch() + def test_e2e(self): exported_program, inputs = create_program(ModuleAdd()) executorch_module = self.load_fn(exported_program.buffer) @@ -204,6 +218,37 @@ def test_channels_last_in_default_out(self) -> None: expected = model(inputs[0]) self.assertTrue(torch.allclose(expected, executorch_output)) + def test_non_contiguous_input_dim_order_mismatch(self) -> None: + """Contiguous tensor passed to a channels-last model raises on dim order mismatch.""" + model = ModuleChannelsLast() + exported_program, inputs = create_program(model) + executorch_module = self.load_fn(exported_program.buffer) + + # Model expects channels-last; passing contiguous triggers mismatch. 
+ contiguous_input = inputs[0].contiguous() + self.assertRaises(RuntimeError, executorch_module, contiguous_input) + + def test_channels_last_input_dim_order_mismatch(self) -> None: + """Channels-last tensor passed to a contiguous model raises on dim order mismatch.""" + model = ModuleChannelsLast() + inputs = (torch.ones(1, 2, 3, 4),) + exported_program = self._create_program_with_inputs(model, inputs) + executorch_module = self.load_fn(exported_program.buffer) + + channels_last_input = inputs[0].to(memory_format=torch.channels_last) + self.assertRaises(RuntimeError, executorch_module, channels_last_input) + + def test_strided_input_dim_order_error(self) -> None: + model = ModuleChannelsLast() + exported_program = self._create_program_with_inputs( + model, (torch.ones(1, 2, 3, 4),) + ) + executorch_module = self.load_fn(exported_program.buffer) + + self.assertRaises( + RuntimeError, executorch_module, self._make_non_dense_strided_input() + ) + def test_method_meta(self) -> None: exported_program, inputs = create_program(ModuleAdd()) @@ -466,6 +511,40 @@ def test_method_channels_last_in_default_out(self) -> None: expected = model(inputs[0]) self.assertTrue(torch.allclose(expected, executorch_output)) + def test_method_non_contiguous_input_dim_order_mismatch(self) -> None: + """Contiguous tensor passed to a channels-last method raises on dim order mismatch.""" + model = ModuleChannelsLast() + exported_program, inputs = create_program(model) + executorch_program = self.load_prog_fn(exported_program.buffer) + executorch_method = executorch_program.load_method("forward") + + # Model expects channels-last; passing contiguous triggers mismatch. 
+ contiguous_input = inputs[0].contiguous() + self.assertRaises(RuntimeError, executorch_method, contiguous_input) + + def test_method_channels_last_input_dim_order_mismatch(self) -> None: + """Channels-last tensor passed to a contiguous method raises on dim order mismatch.""" + model = ModuleChannelsLast() + inputs = (torch.ones(1, 2, 3, 4),) + exported_program = self._create_program_with_inputs(model, inputs) + executorch_program = self.load_prog_fn(exported_program.buffer) + executorch_method = executorch_program.load_method("forward") + + channels_last_input = inputs[0].to(memory_format=torch.channels_last) + self.assertRaises(RuntimeError, executorch_method, channels_last_input) + + def test_method_strided_input_dim_order_error(self) -> None: + model = ModuleChannelsLast() + exported_program = self._create_program_with_inputs( + model, (torch.ones(1, 2, 3, 4),) + ) + executorch_program = self.load_prog_fn(exported_program.buffer) + executorch_method = executorch_program.load_method("forward") + + self.assertRaises( + RuntimeError, executorch_method, self._make_non_dense_strided_input() + ) + def test_method_bad_name(self) -> None: exported_program, inputs = create_program(ModuleAdd()) executorch_program = self.load_prog_fn(exported_program.buffer) diff --git a/runtime/core/exec_aten/util/tensor_util.h b/runtime/core/exec_aten/util/tensor_util.h index 26b97e5a7a2..b9fcc07fba1 100644 --- a/runtime/core/exec_aten/util/tensor_util.h +++ b/runtime/core/exec_aten/util/tensor_util.h @@ -1140,6 +1140,13 @@ bool extract_scalar_tensor(executorch::aten::Tensor tensor, BOOL_T* out_val) { /// These APIs should not be used outside of Executor.cpp. namespace internal { +/** + * Validate that t_src matches the model's expected dim order. + */ +ET_NODISCARD Error check_tensor_data_layout( + const executorch::aten::Tensor& t_src, + executorch::runtime::Span<const uint8_t> expected_dim_order); + /** * Share t_src's data_ptr with t_dst.
*/ diff --git a/runtime/core/exec_aten/util/tensor_util_aten.cpp b/runtime/core/exec_aten/util/tensor_util_aten.cpp index b8d8e266016..d42c5b8446c 100644 --- a/runtime/core/exec_aten/util/tensor_util_aten.cpp +++ b/runtime/core/exec_aten/util/tensor_util_aten.cpp @@ -120,7 +120,20 @@ bool tensors_have_same_dim_order( namespace internal { -Error share_tensor_data(const at::Tensor& t_dst, const at::Tensor& t_src) { +Error check_tensor_data_layout( + const at::Tensor& t_src, + executorch::runtime::Span<const uint8_t> expected_dim_order) { + // There's some annoyance on teams doing weird stuff today that's hard to + // migrate. ATen mode isn't supported in the CMake build and it's really only + // used for testing. So we can just skip this check for now. + (void)t_src; + (void)expected_dim_order; + return Error::Ok; +} + +Error share_tensor_data( + const at::Tensor& t_dst, + const at::Tensor& t_src) { at::StorageImpl* storage = t_dst.unsafeGetTensorImpl()->unsafe_storage().unsafeGetStorageImpl(); @@ -143,7 +156,9 @@ Error share_tensor_data(const at::Tensor& t_dst, const at::Tensor& t_src) { return Error::Ok; } -Error copy_tensor_data(const at::Tensor& t_dst, const at::Tensor& t_src) { +Error copy_tensor_data( + const at::Tensor& t_dst, + const at::Tensor& t_src) { void* dst_data_ptr = t_dst.unsafeGetTensorImpl() ->unsafe_storage() .unsafeGetStorageImpl() diff --git a/runtime/core/exec_aten/util/tensor_util_portable.cpp b/runtime/core/exec_aten/util/tensor_util_portable.cpp index 9626974ad7d..48493f159ad 100644 --- a/runtime/core/exec_aten/util/tensor_util_portable.cpp +++ b/runtime/core/exec_aten/util/tensor_util_portable.cpp @@ -137,6 +137,32 @@ bool tensors_have_same_dim_order( namespace internal { +Error check_tensor_data_layout( + const torch::executor::Tensor& t_src, + executorch::runtime::Span<const uint8_t> expected_dim_order) { + // Input tensors don't actually have to have a dim order, so we can't check. + // Tech debt from dim order being a later addition to the tensor spec.
+ if (t_src.dim_order().data() == nullptr) { + return Error::Ok; + } + + auto src_dim_order = t_src.dim_order(); + ET_CHECK_OR_RETURN_ERROR( + expected_dim_order.size() == src_dim_order.size(), + InvalidArgument, + "Input dim order size (%zu) != expected (%zu).", + src_dim_order.size(), + expected_dim_order.size()); + ET_CHECK_OR_RETURN_ERROR( + std::equal( + expected_dim_order.begin(), + expected_dim_order.end(), + src_dim_order.data()), + InvalidArgument, + "Input dim order does not match the model."); + return Error::Ok; +} + Error share_tensor_data( const torch::executor::Tensor& t_dst, const torch::executor::Tensor& t_src) { diff --git a/runtime/executor/method.cpp b/runtime/executor/method.cpp index 606b2460155..6b55a0493ad 100644 --- a/runtime/executor/method.cpp +++ b/runtime/executor/method.cpp @@ -1184,6 +1184,12 @@ Method::set_input(const EValue& input_evalue, size_t input_idx) { "Error resizing tensor at input %" ET_PRIsize_t, input_idx); auto tensor_meta = this->method_meta().input_tensor_meta(input_idx); + ET_CHECK_OK_OR_RETURN_ERROR(tensor_meta.error()); + auto expected_dim_order = tensor_meta->dim_order(); + ET_CHECK_OK_OR_RETURN_ERROR( + internal::check_tensor_data_layout(t_src, expected_dim_order), + "Error validating tensor layout at input %" ET_PRIsize_t, + input_idx); if (tensor_meta->is_memory_planned()) { ET_CHECK_OK_OR_RETURN_ERROR( internal::copy_tensor_data(t_dst, t_src),