Skip to content

Commit 525cf4c

Browse files
ChinChangYangclaude
andcommitted
Add test_mish_stability to verify mish accuracy with fixed weights and inputs
Test uses a Conv2d+Mish+Flatten+Linear model with explicit uniform weights (conv=1.0, bias=0.0) and linspace inputs at three scales (0.1, 3.5, 11.0), producing known mish input intervals (~[-0.9,0.9], ~[-31.5,31.5], ~[-99,99]) to demonstrate stability across large negative and positive values on Neural Engine. This addresses the PR apple#2618 review feedback requesting deterministic test coverage of the large-negative-x regime (x=-30, x=-100). Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent b411347 commit 525cf4c

1 file changed

Lines changed: 45 additions & 0 deletions

File tree

coremltools/converters/mil/frontend/torch/test/test_torch_ops.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6532,6 +6532,51 @@ def test_mish(self, compute_unit, backend, frontend, shape):
65326532
shape, model, frontend=frontend, backend=backend, compute_unit=compute_unit
65336533
)
65346534

6535+
@pytest.mark.parametrize(
6536+
"compute_unit, backend, frontend, scale",
6537+
itertools.product(compute_units, backends, frontends, [0.1, 3.5, 11.0]),
6538+
)
6539+
def test_mish_stability(self, compute_unit, backend, frontend, scale):
6540+
class MishModel(nn.Module):
6541+
def __init__(self):
6542+
super().__init__()
6543+
self.conv1 = nn.Conv2d(1, 16, kernel_size=3, padding="same")
6544+
self.act = nn.Mish()
6545+
self.flatten = nn.Flatten()
6546+
self.fc1 = nn.Linear(28 * 28 * 16, 10)
6547+
6548+
def forward(self, x):
6549+
x = self.act(self.conv1(x))
6550+
x = self.flatten(x)
6551+
x = self.fc1(x)
6552+
return x
6553+
6554+
model = MishModel().eval()
6555+
6556+
# Fixed weights: conv weight=1.0, bias=0.0
6557+
# Each interior conv output pixel = sum of 9 input values ≈ 9 * local_value
6558+
# Mish input interval ≈ [-9*scale, 9*scale]
6559+
# scale=0.1 → mish interval ≈ [-0.9, 0.9] (small values)
6560+
# scale=3.5 → mish interval ≈ [-31.5, 31.5] (covers x=-30 regime)
6561+
# scale=11.0 → mish interval ≈ [-99, 99] (covers x=-100 regime)
6562+
with torch.no_grad():
6563+
model.conv1.weight.fill_(1.0)
6564+
model.conv1.bias.fill_(0.0)
6565+
model.fc1.weight.fill_(0.01)
6566+
model.fc1.bias.fill_(0.0)
6567+
6568+
# Fixed input: 28x28 values from -scale to +scale
6569+
x = torch.linspace(-scale, scale, 28 * 28).reshape(1, 1, 28, 28)
6570+
6571+
TorchBaseTest.run_compare_torch(
6572+
x,
6573+
model,
6574+
input_as_shape=False,
6575+
frontend=frontend,
6576+
backend=backend,
6577+
compute_unit=compute_unit,
6578+
)
6579+
65356580
@pytest.mark.parametrize(
65366581
"compute_unit, backend, frontend, shape",
65376582
itertools.product(compute_units, backends, frontends, COMMON_SHAPES_ALL),

0 commit comments

Comments
 (0)