@@ -54,6 +54,20 @@ def setUp(self):
5454 self .load_prog_fn = runtime ._load_program_from_buffer
5555 self .runtime = runtime
5656
57+ def _make_non_dense_strided_input (self ) -> torch .Tensor :
58+ strided_input = torch .arange (1 * 2 * 3 * 8 , dtype = torch .float32 ).reshape (
59+ 1 , 2 , 3 , 8
60+ )[..., ::2 ]
61+ self .assertEqual (tuple (strided_input .shape ), (1 , 2 , 3 , 4 ))
62+ self .assertFalse (strided_input .is_contiguous ())
63+ self .assertFalse (strided_input .is_contiguous (memory_format = torch .channels_last ))
64+ return strided_input
65+
66+ def _create_program_with_inputs (
67+ self , eager_module : torch .nn .Module , inputs : tuple [torch .Tensor , ...]
68+ ):
69+ return to_edge (export (eager_module , inputs , strict = True )).to_executorch ()
70+
5771 def test_e2e (self ):
5872 exported_program , inputs = create_program (ModuleAdd ())
5973 executorch_module = self .load_fn (exported_program .buffer )
@@ -204,6 +218,37 @@ def test_channels_last_in_default_out(self) -> None:
204218 expected = model (inputs [0 ])
205219 self .assertTrue (torch .allclose (expected , executorch_output ))
206220
221+ def test_non_contiguous_input_dim_order_mismatch (self ) -> None :
222+ """Contiguous tensor passed to a channels-last model raises on dim order mismatch."""
223+ model = ModuleChannelsLast ()
224+ exported_program , inputs = create_program (model )
225+ executorch_module = self .load_fn (exported_program .buffer )
226+
227+ # Model expects channels-last; passing contiguous triggers mismatch.
228+ contiguous_input = inputs [0 ].contiguous ()
229+ self .assertRaises (RuntimeError , executorch_module , contiguous_input )
230+
231+ def test_channels_last_input_dim_order_mismatch (self ) -> None :
232+ """Channels-last tensor passed to a contiguous model raises on dim order mismatch."""
233+ model = ModuleChannelsLast ()
234+ inputs = (torch .ones (1 , 2 , 3 , 4 ),)
235+ exported_program = self ._create_program_with_inputs (model , inputs )
236+ executorch_module = self .load_fn (exported_program .buffer )
237+
238+ channels_last_input = inputs [0 ].to (memory_format = torch .channels_last )
239+ self .assertRaises (RuntimeError , executorch_module , channels_last_input )
240+
241+ def test_strided_input_dim_order_error (self ) -> None :
242+ model = ModuleChannelsLast ()
243+ exported_program = self ._create_program_with_inputs (
244+ model , (torch .ones (1 , 2 , 3 , 4 ),)
245+ )
246+ executorch_module = self .load_fn (exported_program .buffer )
247+
248+ self .assertRaises (
249+ RuntimeError , executorch_module , self ._make_non_dense_strided_input ()
250+ )
251+
207252 def test_method_meta (self ) -> None :
208253 exported_program , inputs = create_program (ModuleAdd ())
209254
@@ -466,6 +511,40 @@ def test_method_channels_last_in_default_out(self) -> None:
466511 expected = model (inputs [0 ])
467512 self .assertTrue (torch .allclose (expected , executorch_output ))
468513
514+ def test_method_non_contiguous_input_dim_order_mismatch (self ) -> None :
515+ """Contiguous tensor passed to a channels-last method raises on dim order mismatch."""
516+ model = ModuleChannelsLast ()
517+ exported_program , inputs = create_program (model )
518+ executorch_program = self .load_prog_fn (exported_program .buffer )
519+ executorch_method = executorch_program .load_method ("forward" )
520+
521+ # Model expects channels-last; passing contiguous triggers mismatch.
522+ contiguous_input = inputs [0 ].contiguous ()
523+ self .assertRaises (RuntimeError , executorch_method , contiguous_input )
524+
525+ def test_method_channels_last_input_dim_order_mismatch (self ) -> None :
526+ """Channels-last tensor passed to a contiguous method raises on dim order mismatch."""
527+ model = ModuleChannelsLast ()
528+ inputs = (torch .ones (1 , 2 , 3 , 4 ),)
529+ exported_program = self ._create_program_with_inputs (model , inputs )
530+ executorch_program = self .load_prog_fn (exported_program .buffer )
531+ executorch_method = executorch_program .load_method ("forward" )
532+
533+ channels_last_input = inputs [0 ].to (memory_format = torch .channels_last )
534+ self .assertRaises (RuntimeError , executorch_method , channels_last_input )
535+
536+ def test_method_strided_input_dim_order_error (self ) -> None :
537+ model = ModuleChannelsLast ()
538+ exported_program = self ._create_program_with_inputs (
539+ model , (torch .ones (1 , 2 , 3 , 4 ),)
540+ )
541+ executorch_program = self .load_prog_fn (exported_program .buffer )
542+ executorch_method = executorch_program .load_method ("forward" )
543+
544+ self .assertRaises (
545+ RuntimeError , executorch_method , self ._make_non_dense_strided_input ()
546+ )
547+
469548 def test_method_bad_name (self ) -> None :
470549 exported_program , inputs = create_program (ModuleAdd ())
471550 executorch_program = self .load_prog_fn (exported_program .buffer )
0 commit comments