@@ -338,6 +338,123 @@ def forward(self, x):
338338 torch .allclose (et_outputs , eager_outputs , atol = 1e-02 , rtol = 1e-02 )
339339 )
340340
341+ def test_previously_failing_ops_lower_successfully (self ):
342+ """
343+ Each of these snippets used to crash the CoreML partitioner / lowering
344+ pipeline. They are kept here as regression tests so any future change
345+ that re-breaks one of them surfaces in CI.
346+
347+ - Conv1d with stride>kernel and groups>1 (#11688)
348+ - int32 matmul with a constant weight (#11691)
349+ - BatchNorm3d / InstanceNorm3d on a rank-5 input (#11701, #11702)
350+ - ReflectionPad3d / ReplicationPad3d (#11708, #11709)
351+ """
352+
353+ cases = []
354+
355+ class Conv1dCase (torch .nn .Module ):
356+ def __init__ (self ):
357+ super ().__init__ ()
358+ self .conv = torch .nn .Conv1d (
359+ 16 , 4 , 6 , stride = 8 , padding = 0 , dilation = 2 , groups = 2 , bias = False
360+ )
361+
362+ def forward (self , x ):
363+ return self .conv (x )
364+
365+ cases .append (("issue_11688_conv1d" , Conv1dCase ().eval (), (torch .randn (2 , 16 , 11 ),)))
366+
367+ class Int32MmCase (torch .nn .Module ):
368+ def __init__ (self ):
369+ super ().__init__ ()
370+ self .weight = torch .randint (0 , 100 , (8 , 8 )).to (torch .int32 )
371+
372+ def forward (self , x ):
373+ return torch .mm (x , self .weight )
374+
375+ cases .append (
376+ (
377+ "issue_11691_int32_mm" ,
378+ Int32MmCase ().eval (),
379+ (torch .randn (8 , 8 ).to (torch .int32 ),),
380+ )
381+ )
382+
383+ class BatchNorm3dCase (torch .nn .Module ):
384+ def __init__ (self ):
385+ super ().__init__ ()
386+ self .norm = torch .nn .BatchNorm3d (3 )
387+
388+ def forward (self , x ):
389+ return self .norm (x )
390+
391+ cases .append (
392+ (
393+ "issue_11701_batchnorm3d" ,
394+ BatchNorm3dCase ().eval (),
395+ (torch .randn (1 , 3 , 4 , 4 , 4 ),),
396+ )
397+ )
398+
399+ class InstanceNorm3dCase (torch .nn .Module ):
400+ def __init__ (self ):
401+ super ().__init__ ()
402+ self .norm = torch .nn .InstanceNorm3d (3 )
403+
404+ def forward (self , x ):
405+ return self .norm (x )
406+
407+ cases .append (
408+ (
409+ "issue_11702_instancenorm3d" ,
410+ InstanceNorm3dCase ().eval (),
411+ (torch .randn (1 , 3 , 4 , 4 , 4 ),),
412+ )
413+ )
414+
415+ class ReflectionPad3dCase (torch .nn .Module ):
416+ def __init__ (self ):
417+ super ().__init__ ()
418+ self .pad = torch .nn .ReflectionPad3d (2 )
419+
420+ def forward (self , x ):
421+ return self .pad (x )
422+
423+ cases .append (
424+ (
425+ "issue_11708_reflection_pad3d" ,
426+ ReflectionPad3dCase ().eval (),
427+ (torch .randn (1 , 6 , 6 , 6 , 6 ),),
428+ )
429+ )
430+
431+ class ReplicationPad3dCase (torch .nn .Module ):
432+ def __init__ (self ):
433+ super ().__init__ ()
434+ self .pad = torch .nn .ReplicationPad3d (2 )
435+
436+ def forward (self , x ):
437+ return self .pad (x )
438+
439+ cases .append (
440+ (
441+ "issue_11709_replication_pad3d" ,
442+ ReplicationPad3dCase ().eval (),
443+ (torch .randn (1 , 6 , 6 , 6 , 6 ),),
444+ )
445+ )
446+
447+ for name , model , example_inputs in cases :
448+ with self .subTest (name = name ):
449+ ep = torch .export .export (model , example_inputs , strict = True )
450+ executorch .exir .to_edge_transform_and_lower (
451+ ep ,
452+ partitioner = [CoreMLPartitioner ()],
453+ compile_config = executorch .exir .EdgeCompileConfig (
454+ _check_ir_validity = False
455+ ),
456+ )
457+
341458 def test_deprecation_warning_for_to_backend_workflow (self ):
342459 """
343460 Test that the deprecated to_edge + to_backend workflow shows a deprecation warning.
0 commit comments