Skip to content

Commit 0dd03a4

Browse files
JacobSzwejbka authored and facebook-github-bot committed
Verify input dim order (#18643)
Summary: Add error checks around dim order and memory format. Differential Revision: D99129347
1 parent 36e8ed9 commit 0dd03a4

5 files changed

Lines changed: 137 additions & 2 deletions

File tree

extension/pybindings/test/test_pybindings.py

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,22 @@ def setUp(self):
5454
self.load_prog_fn = runtime._load_program_from_buffer
5555
self.runtime = runtime
5656

57+
def _make_non_dense_strided_input(self) -> torch.Tensor:
58+
strided_input = torch.arange(1 * 2 * 3 * 8, dtype=torch.float32).reshape(
59+
1, 2, 3, 8
60+
)[..., ::2]
61+
self.assertEqual(tuple(strided_input.shape), (1, 2, 3, 4))
62+
self.assertFalse(strided_input.is_contiguous())
63+
self.assertFalse(
64+
strided_input.is_contiguous(memory_format=torch.channels_last)
65+
)
66+
return strided_input
67+
68+
def _create_program_with_inputs(
    self, eager_module: torch.nn.Module, inputs: tuple[torch.Tensor, ...]
):
    """Export ``eager_module`` with ``inputs`` and lower it to an ExecuTorch program."""
    exported = export(eager_module, inputs, strict=True)
    return to_edge(exported).to_executorch()
72+
5773
def test_e2e(self):
5874
exported_program, inputs = create_program(ModuleAdd())
5975
executorch_module = self.load_fn(exported_program.buffer)
@@ -204,6 +220,37 @@ def test_channels_last_in_default_out(self) -> None:
204220
expected = model(inputs[0])
205221
self.assertTrue(torch.allclose(expected, executorch_output))
206222

223+
def test_non_contiguous_input_dim_order_mismatch(self) -> None:
    """Contiguous tensor passed to a channels-last model raises on dim order mismatch."""
    exported_program, inputs = create_program(ModuleChannelsLast())
    loaded_module = self.load_fn(exported_program.buffer)

    # The program was exported for channels-last inputs, so forcing the
    # sample input back to a contiguous layout must be rejected.
    with self.assertRaises(RuntimeError):
        loaded_module(inputs[0].contiguous())
232+
233+
def test_channels_last_input_dim_order_mismatch(self) -> None:
    """Channels-last tensor passed to a contiguous model raises on dim order mismatch."""
    inputs = (torch.ones(1, 2, 3, 4),)
    exported_program = self._create_program_with_inputs(
        ModuleChannelsLast(), inputs
    )
    loaded_module = self.load_fn(exported_program.buffer)

    # Exported with a contiguous sample input, so a channels-last tensor
    # must trigger the dim-order check.
    with self.assertRaises(RuntimeError):
        loaded_module(inputs[0].to(memory_format=torch.channels_last))
242+
243+
def test_strided_input_dim_order_error(self) -> None:
    """A non-dense strided tensor is rejected as a model input."""
    exported_program = self._create_program_with_inputs(
        ModuleChannelsLast(), (torch.ones(1, 2, 3, 4),)
    )
    loaded_module = self.load_fn(exported_program.buffer)

    with self.assertRaises(RuntimeError):
        loaded_module(self._make_non_dense_strided_input())
253+
207254
def test_method_meta(self) -> None:
208255
exported_program, inputs = create_program(ModuleAdd())
209256

@@ -466,6 +513,40 @@ def test_method_channels_last_in_default_out(self) -> None:
466513
expected = model(inputs[0])
467514
self.assertTrue(torch.allclose(expected, executorch_output))
468515

516+
def test_method_non_contiguous_input_dim_order_mismatch(self) -> None:
    """Contiguous tensor passed to a channels-last method raises on dim order mismatch."""
    exported_program, inputs = create_program(ModuleChannelsLast())
    program = self.load_prog_fn(exported_program.buffer)
    method = program.load_method("forward")

    # The method was exported for channels-last inputs, so a contiguous
    # tensor must be rejected.
    with self.assertRaises(RuntimeError):
        method(inputs[0].contiguous())
526+
527+
def test_method_channels_last_input_dim_order_mismatch(self) -> None:
    """Channels-last tensor passed to a contiguous method raises on dim order mismatch."""
    inputs = (torch.ones(1, 2, 3, 4),)
    exported_program = self._create_program_with_inputs(
        ModuleChannelsLast(), inputs
    )
    program = self.load_prog_fn(exported_program.buffer)
    method = program.load_method("forward")

    # Exported with a contiguous sample input, so a channels-last tensor
    # must trigger the dim-order check.
    with self.assertRaises(RuntimeError):
        method(inputs[0].to(memory_format=torch.channels_last))
537+
538+
def test_method_strided_input_dim_order_error(self) -> None:
    """A non-dense strided tensor is rejected as a method input."""
    exported_program = self._create_program_with_inputs(
        ModuleChannelsLast(), (torch.ones(1, 2, 3, 4),)
    )
    program = self.load_prog_fn(exported_program.buffer)
    method = program.load_method("forward")

    with self.assertRaises(RuntimeError):
        method(self._make_non_dense_strided_input())
549+
469550
def test_method_bad_name(self) -> None:
470551
exported_program, inputs = create_program(ModuleAdd())
471552
executorch_program = self.load_prog_fn(exported_program.buffer)

runtime/core/exec_aten/util/tensor_util.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1140,6 +1140,13 @@ bool extract_scalar_tensor(executorch::aten::Tensor tensor, BOOL_T* out_val) {
11401140

11411141
/// These APIs should not be used outside of Executor.cpp.
11421142
namespace internal {
1143+
/**
 * Validate that t_src matches the model's expected dim order.
 *
 * @param t_src The input tensor whose dim order is being checked.
 * @param expected_dim_order The dim order the program recorded for this
 *     input.
 * @return Error::Ok when the layout matches (or cannot be checked in the
 *     current build); an error on a detected mismatch.
 */
ET_NODISCARD Error check_tensor_data_layout(
    const executorch::aten::Tensor& t_src,
    executorch::runtime::Span<const uint8_t> expected_dim_order);
1149+
11431150
/**
11441151
* Share t_src's data_ptr with t_dst.
11451152
*/

runtime/core/exec_aten/util/tensor_util_aten.cpp

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,20 @@ bool tensors_have_same_dim_order(
120120

121121
namespace internal {
122122

123-
Error share_tensor_data(const at::Tensor& t_dst, const at::Tensor& t_src) {
123+
Error check_tensor_data_layout(
    const at::Tensor& t_src,
    executorch::runtime::Span<const uint8_t> expected_dim_order) {
  // Intentionally a no-op: some teams rely on unusual layouts today that are
  // hard to migrate, and ATen mode isn't supported in the CMake build — it's
  // really only used for testing — so the dim-order check is skipped for now.
  (void)t_src;
  (void)expected_dim_order;
  return Error::Ok;
}
133+
134+
Error share_tensor_data(
135+
const at::Tensor& t_dst,
136+
const at::Tensor& t_src) {
124137
at::StorageImpl* storage =
125138
t_dst.unsafeGetTensorImpl()->unsafe_storage().unsafeGetStorageImpl();
126139

@@ -143,7 +156,9 @@ Error share_tensor_data(const at::Tensor& t_dst, const at::Tensor& t_src) {
143156
return Error::Ok;
144157
}
145158

146-
Error copy_tensor_data(const at::Tensor& t_dst, const at::Tensor& t_src) {
159+
Error copy_tensor_data(
160+
const at::Tensor& t_dst,
161+
const at::Tensor& t_src) {
147162
void* dst_data_ptr = t_dst.unsafeGetTensorImpl()
148163
->unsafe_storage()
149164
.unsafeGetStorageImpl()

runtime/core/exec_aten/util/tensor_util_portable.cpp

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -137,6 +137,32 @@ bool tensors_have_same_dim_order(
137137

138138
namespace internal {
139139

140+
Error check_tensor_data_layout(
    const torch::executor::Tensor& t_src,
    executorch::runtime::Span<const uint8_t> expected_dim_order) {
  // Input tensors are permitted to carry no dim order at all (tech debt:
  // dim order was a later addition to the tensor spec), so there is nothing
  // to validate in that case.
  const auto actual_dim_order = t_src.dim_order();
  if (actual_dim_order.data() == nullptr) {
    return Error::Ok;
  }

  ET_CHECK_OR_RETURN_ERROR(
      actual_dim_order.size() == expected_dim_order.size(),
      InvalidArgument,
      "Input dim order size (%zu) != expected (%zu).",
      actual_dim_order.size(),
      expected_dim_order.size());
  // Element-wise comparison; report the first mismatch, if any.
  for (size_t i = 0; i < expected_dim_order.size(); ++i) {
    ET_CHECK_OR_RETURN_ERROR(
        actual_dim_order[i] == expected_dim_order[i],
        InvalidArgument,
        "Input dim order does not match the model.");
  }
  return Error::Ok;
}
165+
140166
Error share_tensor_data(
141167
const torch::executor::Tensor& t_dst,
142168
const torch::executor::Tensor& t_src) {

runtime/executor/method.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1184,6 +1184,12 @@ Method::set_input(const EValue& input_evalue, size_t input_idx) {
11841184
"Error resizing tensor at input %" ET_PRIsize_t,
11851185
input_idx);
11861186
auto tensor_meta = this->method_meta().input_tensor_meta(input_idx);
1187+
ET_CHECK_OK_OR_RETURN_ERROR(tensor_meta.error());
1188+
auto expected_dim_order = tensor_meta->dim_order();
1189+
ET_CHECK_OK_OR_RETURN_ERROR(
1190+
internal::check_tensor_data_layout(t_src, expected_dim_order),
1191+
"Error validating tensor layout at input %" ET_PRIsize_t,
1192+
input_idx);
11871193
if (tensor_meta->is_memory_planned()) {
11881194
ET_CHECK_OK_OR_RETURN_ERROR(
11891195
internal::copy_tensor_data(t_dst, t_src),

0 commit comments

Comments
 (0)