Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
117 changes: 117 additions & 0 deletions backends/apple/coreml/test/test_coreml_partitioner.py
Original file line number Diff line number Diff line change
Expand Up @@ -338,6 +338,123 @@ def forward(self, x):
torch.allclose(et_outputs, eager_outputs, atol=1e-02, rtol=1e-02)
)

def test_previously_failing_ops_lower_successfully(self):
    """
    Regression tests for snippets that used to crash the CoreML
    partitioner / lowering pipeline. Each case is kept here so that a
    future change which re-breaks one of them surfaces in CI.

    - Conv1d with stride>kernel and groups>1 (#11688)
    - int32 matmul with a constant weight (#11691)
    - BatchNorm3d / InstanceNorm3d on a rank-5 input (#11701, #11702)
    - ReflectionPad3d / ReplicationPad3d (#11708, #11709)
    """

    class _SingleLayer(torch.nn.Module):
        # Generic wrapper: export a single torch.nn layer as a standalone model.
        def __init__(self, layer):
            super().__init__()
            self.layer = layer

        def forward(self, x):
            return self.layer(x)

    class _Int32Mm(torch.nn.Module):
        # torch.mm against a constant int32 weight held as a plain
        # attribute (not a Parameter/buffer), so export captures it as
        # a graph constant.
        def __init__(self):
            super().__init__()
            self.weight = torch.randint(0, 100, (8, 8)).to(torch.int32)

        def forward(self, x):
            return torch.mm(x, self.weight)

    cases = [
        (
            "issue_11688_conv1d",
            _SingleLayer(
                torch.nn.Conv1d(
                    16, 4, 6, stride=8, padding=0, dilation=2, groups=2, bias=False
                )
            ).eval(),
            (torch.randn(2, 16, 11),),
        ),
        (
            "issue_11691_int32_mm",
            _Int32Mm().eval(),
            (torch.randn(8, 8).to(torch.int32),),
        ),
        (
            "issue_11701_batchnorm3d",
            _SingleLayer(torch.nn.BatchNorm3d(3)).eval(),
            (torch.randn(1, 3, 4, 4, 4),),
        ),
        (
            "issue_11702_instancenorm3d",
            _SingleLayer(torch.nn.InstanceNorm3d(3)).eval(),
            (torch.randn(1, 3, 4, 4, 4),),
        ),
        (
            "issue_11708_reflection_pad3d",
            _SingleLayer(torch.nn.ReflectionPad3d(2)).eval(),
            (torch.randn(1, 6, 6, 6, 6),),
        ),
        (
            "issue_11709_replication_pad3d",
            _SingleLayer(torch.nn.ReplicationPad3d(2)).eval(),
            (torch.randn(1, 6, 6, 6, 6),),
        ),
    ]

    for case_name, module, inputs in cases:
        with self.subTest(name=case_name):
            exported = torch.export.export(module, inputs, strict=True)
            # Lowering without raising IS the assertion for these
            # regression cases; the result is intentionally discarded.
            executorch.exir.to_edge_transform_and_lower(
                exported,
                partitioner=[CoreMLPartitioner()],
                compile_config=executorch.exir.EdgeCompileConfig(
                    _check_ir_validity=False
                ),
            )

def test_deprecation_warning_for_to_backend_workflow(self):
"""
Test that the deprecated to_edge + to_backend workflow shows a deprecation warning.
Expand Down
Loading