Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 79 additions & 0 deletions extension/pybindings/test/test_pybindings.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,20 @@ def setUp(self):
self.load_prog_fn = runtime._load_program_from_buffer
self.runtime = runtime

def _make_non_dense_strided_input(self) -> torch.Tensor:
strided_input = torch.arange(1 * 2 * 3 * 8, dtype=torch.float32).reshape(
1, 2, 3, 8
)[..., ::2]
self.assertEqual(tuple(strided_input.shape), (1, 2, 3, 4))
self.assertFalse(strided_input.is_contiguous())
self.assertFalse(strided_input.is_contiguous(memory_format=torch.channels_last))
return strided_input

def _create_program_with_inputs(
    self, eager_module: torch.nn.Module, inputs: tuple[torch.Tensor, ...]
):
    """Export ``eager_module`` against ``inputs`` and lower it to ExecuTorch."""
    exported = export(eager_module, inputs, strict=True)
    edge_program = to_edge(exported)
    return edge_program.to_executorch()

def test_e2e(self):
exported_program, inputs = create_program(ModuleAdd())
executorch_module = self.load_fn(exported_program.buffer)
Expand Down Expand Up @@ -204,6 +218,37 @@ def test_channels_last_in_default_out(self) -> None:
expected = model(inputs[0])
self.assertTrue(torch.allclose(expected, executorch_output))

def test_non_contiguous_input_dim_order_mismatch(self) -> None:
    """Contiguous tensor passed to a channels-last model raises on dim order mismatch."""
    model = ModuleChannelsLast()
    exported_program, inputs = create_program(model)
    executorch_module = self.load_fn(exported_program.buffer)

    # The program was exported for channels-last input, so a contiguous
    # (default dim order) tensor must be rejected at call time.
    with self.assertRaises(RuntimeError):
        executorch_module(inputs[0].contiguous())

def test_channels_last_input_dim_order_mismatch(self) -> None:
    """Channels-last tensor passed to a contiguous model raises on dim order mismatch."""
    contiguous_inputs = (torch.ones(1, 2, 3, 4),)
    exported_program = self._create_program_with_inputs(
        ModuleChannelsLast(), contiguous_inputs
    )
    executorch_module = self.load_fn(exported_program.buffer)

    # Re-laying out the input as channels-last no longer matches the
    # contiguous dim order the program was exported with.
    bad_input = contiguous_inputs[0].to(memory_format=torch.channels_last)
    with self.assertRaises(RuntimeError):
        executorch_module(bad_input)

def test_strided_input_dim_order_error(self) -> None:
    """A non-dense strided input tensor is rejected by the loaded module."""
    exported_program = self._create_program_with_inputs(
        ModuleChannelsLast(), (torch.ones(1, 2, 3, 4),)
    )
    executorch_module = self.load_fn(exported_program.buffer)

    with self.assertRaises(RuntimeError):
        executorch_module(self._make_non_dense_strided_input())

def test_method_meta(self) -> None:
exported_program, inputs = create_program(ModuleAdd())

Expand Down Expand Up @@ -466,6 +511,40 @@ def test_method_channels_last_in_default_out(self) -> None:
expected = model(inputs[0])
self.assertTrue(torch.allclose(expected, executorch_output))

def test_method_non_contiguous_input_dim_order_mismatch(self) -> None:
    """Contiguous tensor passed to a channels-last method raises on dim order mismatch."""
    model = ModuleChannelsLast()
    exported_program, inputs = create_program(model)
    executorch_method = self.load_prog_fn(exported_program.buffer).load_method(
        "forward"
    )

    # The method was exported for channels-last input, so a contiguous
    # (default dim order) tensor must be rejected at call time.
    with self.assertRaises(RuntimeError):
        executorch_method(inputs[0].contiguous())

def test_method_channels_last_input_dim_order_mismatch(self) -> None:
    """Channels-last tensor passed to a contiguous method raises on dim order mismatch."""
    contiguous_inputs = (torch.ones(1, 2, 3, 4),)
    exported_program = self._create_program_with_inputs(
        ModuleChannelsLast(), contiguous_inputs
    )
    executorch_method = self.load_prog_fn(exported_program.buffer).load_method(
        "forward"
    )

    # Re-laying out the input as channels-last no longer matches the
    # contiguous dim order the method was exported with.
    bad_input = contiguous_inputs[0].to(memory_format=torch.channels_last)
    with self.assertRaises(RuntimeError):
        executorch_method(bad_input)

def test_method_strided_input_dim_order_error(self) -> None:
    """A non-dense strided input tensor is rejected by the loaded method."""
    exported_program = self._create_program_with_inputs(
        ModuleChannelsLast(), (torch.ones(1, 2, 3, 4),)
    )
    executorch_method = self.load_prog_fn(exported_program.buffer).load_method(
        "forward"
    )

    with self.assertRaises(RuntimeError):
        executorch_method(self._make_non_dense_strided_input())

def test_method_bad_name(self) -> None:
exported_program, inputs = create_program(ModuleAdd())
executorch_program = self.load_prog_fn(exported_program.buffer)
Expand Down
7 changes: 7 additions & 0 deletions runtime/core/exec_aten/util/tensor_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -1140,6 +1140,13 @@ bool extract_scalar_tensor(executorch::aten::Tensor tensor, BOOL_T* out_val) {

/// These APIs should not be used outside of Executor.cpp.
namespace internal {
/**
 * Validate that t_src matches the model's expected dim order.
 *
 * @param t_src The input tensor supplied by the caller.
 * @param expected_dim_order The dim order the method expects for this input.
 * @returns Error::Ok when the layout is acceptable; an error otherwise.
 *     Note: the ATen-mode implementation currently skips the check and
 *     always returns Error::Ok.
 */
ET_NODISCARD Error check_tensor_data_layout(
    const executorch::aten::Tensor& t_src,
    executorch::runtime::Span<const uint8_t> expected_dim_order);

/**
* Share t_src's data_ptr with t_dst.
*/
Expand Down
19 changes: 17 additions & 2 deletions runtime/core/exec_aten/util/tensor_util_aten.cpp
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
Expand Down Expand Up @@ -120,7 +120,20 @@

namespace internal {

Error share_tensor_data(const at::Tensor& t_dst, const at::Tensor& t_src) {
// ATen-mode stub: intentionally accepts any layout (see comment below).
Error check_tensor_data_layout(
    const at::Tensor& t_src,
    executorch::runtime::Span<const uint8_t> expected_dim_order) {
  // There's some annoyance on teams doing weird stuff today that's hard to
  // migrate. ATen mode isn't supported in the CMake build and it's really only
  // used for testing. So we can just skip this check for now.
  (void)t_src;
  (void)expected_dim_order;
  return Error::Ok;
}

Error share_tensor_data(
const at::Tensor& t_dst,
const at::Tensor& t_src) {
at::StorageImpl* storage =
t_dst.unsafeGetTensorImpl()->unsafe_storage().unsafeGetStorageImpl();

Expand All @@ -143,7 +156,9 @@
return Error::Ok;
}

Error copy_tensor_data(const at::Tensor& t_dst, const at::Tensor& t_src) {
Error copy_tensor_data(
const at::Tensor& t_dst,
const at::Tensor& t_src) {
void* dst_data_ptr = t_dst.unsafeGetTensorImpl()
->unsafe_storage()
.unsafeGetStorageImpl()
Expand Down
26 changes: 26 additions & 0 deletions runtime/core/exec_aten/util/tensor_util_portable.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,32 @@ bool tensors_have_same_dim_order(

namespace internal {

Error check_tensor_data_layout(
    const torch::executor::Tensor& t_src,
    executorch::runtime::Span<const uint8_t> expected_dim_order) {
  // Input tensors don't actually have to have a dim order, so we can't check.
  // Tech debt from dim order being a later addition to the tensor spec.
  const auto src_dim_order = t_src.dim_order();
  if (src_dim_order.data() == nullptr) {
    return Error::Ok;
  }

  ET_CHECK_OR_RETURN_ERROR(
      expected_dim_order.size() == src_dim_order.size(),
      InvalidArgument,
      "Input dim order size (%zu) != expected (%zu).",
      src_dim_order.size(),
      expected_dim_order.size());

  // Element-wise comparison; the sizes are known to match at this point.
  bool matches = true;
  for (size_t i = 0; i < expected_dim_order.size(); ++i) {
    if (expected_dim_order[i] != src_dim_order.data()[i]) {
      matches = false;
      break;
    }
  }
  ET_CHECK_OR_RETURN_ERROR(
      matches,
      InvalidArgument,
      "Input dim order does not match the model.");
  return Error::Ok;
}

Error share_tensor_data(
const torch::executor::Tensor& t_dst,
const torch::executor::Tensor& t_src) {
Expand Down
6 changes: 6 additions & 0 deletions runtime/executor/method.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1184,6 +1184,12 @@ Method::set_input(const EValue& input_evalue, size_t input_idx) {
"Error resizing tensor at input %" ET_PRIsize_t,
input_idx);
auto tensor_meta = this->method_meta().input_tensor_meta(input_idx);
ET_CHECK_OK_OR_RETURN_ERROR(tensor_meta.error());
auto expected_dim_order = tensor_meta->dim_order();
ET_CHECK_OK_OR_RETURN_ERROR(
internal::check_tensor_data_layout(t_src, expected_dim_order),
"Error validating tensor layout at input %" ET_PRIsize_t,
input_idx);
if (tensor_meta->is_memory_planned()) {
ET_CHECK_OK_OR_RETURN_ERROR(
internal::copy_tensor_data(t_dst, t_src),
Expand Down
Loading