@@ -54,6 +54,22 @@ def setUp(self):
5454 self .load_prog_fn = runtime ._load_program_from_buffer
5555 self .runtime = runtime
5656
57+ def _make_non_dense_strided_input (self ) -> torch .Tensor :
58+ strided_input = torch .arange (1 * 2 * 3 * 8 , dtype = torch .float32 ).reshape (
59+ 1 , 2 , 3 , 8
60+ )[..., ::2 ]
61+ self .assertEqual (tuple (strided_input .shape ), (1 , 2 , 3 , 4 ))
62+ self .assertFalse (strided_input .is_contiguous ())
63+ self .assertFalse (
64+ strided_input .is_contiguous (memory_format = torch .channels_last )
65+ )
66+ return strided_input
67+
68+ def _create_program_with_inputs (
69+ self , eager_module : torch .nn .Module , inputs : tuple [torch .Tensor , ...]
70+ ):
71+ return to_edge (export (eager_module , inputs , strict = True )).to_executorch ()
72+
5773 def test_e2e (self ):
5874 exported_program , inputs = create_program (ModuleAdd ())
5975 executorch_module = self .load_fn (exported_program .buffer )
@@ -204,6 +220,37 @@ def test_channels_last_in_default_out(self) -> None:
204220 expected = model (inputs [0 ])
205221 self .assertTrue (torch .allclose (expected , executorch_output ))
206222
223+ def test_non_contiguous_input_dim_order_mismatch (self ) -> None :
224+ """Contiguous tensor passed to a channels-last model raises on dim order mismatch."""
225+ model = ModuleChannelsLast ()
226+ exported_program , inputs = create_program (model )
227+ executorch_module = self .load_fn (exported_program .buffer )
228+
229+ # Model expects channels-last; passing contiguous triggers mismatch.
230+ contiguous_input = inputs [0 ].contiguous ()
231+ self .assertRaises (RuntimeError , executorch_module , contiguous_input )
232+
233+ def test_channels_last_input_dim_order_mismatch (self ) -> None :
234+ """Channels-last tensor passed to a contiguous model raises on dim order mismatch."""
235+ model = ModuleChannelsLast ()
236+ inputs = (torch .ones (1 , 2 , 3 , 4 ),)
237+ exported_program = self ._create_program_with_inputs (model , inputs )
238+ executorch_module = self .load_fn (exported_program .buffer )
239+
240+ channels_last_input = inputs [0 ].to (memory_format = torch .channels_last )
241+ self .assertRaises (RuntimeError , executorch_module , channels_last_input )
242+
243+ def test_strided_input_dim_order_error (self ) -> None :
244+ model = ModuleChannelsLast ()
245+ exported_program = self ._create_program_with_inputs (
246+ model , (torch .ones (1 , 2 , 3 , 4 ),)
247+ )
248+ executorch_module = self .load_fn (exported_program .buffer )
249+
250+ self .assertRaises (
251+ RuntimeError , executorch_module , self ._make_non_dense_strided_input ()
252+ )
253+
207254 def test_method_meta (self ) -> None :
208255 exported_program , inputs = create_program (ModuleAdd ())
209256
@@ -466,6 +513,40 @@ def test_method_channels_last_in_default_out(self) -> None:
466513 expected = model (inputs [0 ])
467514 self .assertTrue (torch .allclose (expected , executorch_output ))
468515
516+ def test_method_non_contiguous_input_dim_order_mismatch (self ) -> None :
517+ """Contiguous tensor passed to a channels-last method raises on dim order mismatch."""
518+ model = ModuleChannelsLast ()
519+ exported_program , inputs = create_program (model )
520+ executorch_program = self .load_prog_fn (exported_program .buffer )
521+ executorch_method = executorch_program .load_method ("forward" )
522+
523+ # Model expects channels-last; passing contiguous triggers mismatch.
524+ contiguous_input = inputs [0 ].contiguous ()
525+ self .assertRaises (RuntimeError , executorch_method , contiguous_input )
526+
527+ def test_method_channels_last_input_dim_order_mismatch (self ) -> None :
528+ """Channels-last tensor passed to a contiguous method raises on dim order mismatch."""
529+ model = ModuleChannelsLast ()
530+ inputs = (torch .ones (1 , 2 , 3 , 4 ),)
531+ exported_program = self ._create_program_with_inputs (model , inputs )
532+ executorch_program = self .load_prog_fn (exported_program .buffer )
533+ executorch_method = executorch_program .load_method ("forward" )
534+
535+ channels_last_input = inputs [0 ].to (memory_format = torch .channels_last )
536+ self .assertRaises (RuntimeError , executorch_method , channels_last_input )
537+
538+ def test_method_strided_input_dim_order_error (self ) -> None :
539+ model = ModuleChannelsLast ()
540+ exported_program = self ._create_program_with_inputs (
541+ model , (torch .ones (1 , 2 , 3 , 4 ),)
542+ )
543+ executorch_program = self .load_prog_fn (exported_program .buffer )
544+ executorch_method = executorch_program .load_method ("forward" )
545+
546+ self .assertRaises (
547+ RuntimeError , executorch_method , self ._make_non_dense_strided_input ()
548+ )
549+
469550 def test_method_bad_name (self ) -> None :
470551 exported_program , inputs = create_program (ModuleAdd ())
471552 executorch_program = self .load_prog_fn (exported_program .buffer )
0 commit comments