Skip to content

Commit 9850e2c

Browse files
ChinChangYangclaude
andcommitted
Add test_mish_stability to verify mish accuracy with fixed weights and inputs
Test uses a Conv2d+Mish+Flatten+Linear model with explicit uniform weights (conv=1.0, bias=0.0) and linspace inputs at three scales (0.1, 3.5, 11.0), producing known mish input intervals (~[-0.9,0.9], ~[-31.5,31.5], ~[-99,99]) to demonstrate stability across large negative and positive values on Neural Engine. This addresses the PR apple#2618 review feedback requesting deterministic test coverage of the large-negative-x regime (x=-30, x=-100). Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent a74c68c commit 9850e2c

1 file changed

Lines changed: 45 additions & 0 deletions

File tree

coremltools/converters/mil/frontend/torch/test/test_torch_ops.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6628,6 +6628,51 @@ def test_mish(self, compute_unit, backend, frontend, shape):
66286628
shape, model, frontend=frontend, backend=backend, compute_unit=compute_unit
66296629
)
66306630

6631+
@pytest.mark.parametrize(
6632+
"compute_unit, backend, frontend, scale",
6633+
itertools.product(compute_units, backends, frontends, [0.1, 3.5, 11.0]),
6634+
)
6635+
def test_mish_stability(self, compute_unit, backend, frontend, scale):
6636+
class MishModel(nn.Module):
6637+
def __init__(self):
6638+
super().__init__()
6639+
self.conv1 = nn.Conv2d(1, 16, kernel_size=3, padding="same")
6640+
self.act = nn.Mish()
6641+
self.flatten = nn.Flatten()
6642+
self.fc1 = nn.Linear(28 * 28 * 16, 10)
6643+
6644+
def forward(self, x):
6645+
x = self.act(self.conv1(x))
6646+
x = self.flatten(x)
6647+
x = self.fc1(x)
6648+
return x
6649+
6650+
model = MishModel().eval()
6651+
6652+
# Fixed weights: conv weight=1.0, bias=0.0
6653+
# Each interior conv output pixel = sum of 9 input values ≈ 9 * local_value
6654+
# Mish input interval ≈ [-9*scale, 9*scale]
6655+
# scale=0.1 → mish interval ≈ [-0.9, 0.9] (small values)
6656+
# scale=3.5 → mish interval ≈ [-31.5, 31.5] (covers x=-30 regime)
6657+
# scale=11.0 → mish interval ≈ [-99, 99] (covers x=-100 regime)
6658+
with torch.no_grad():
6659+
model.conv1.weight.fill_(1.0)
6660+
model.conv1.bias.fill_(0.0)
6661+
model.fc1.weight.fill_(0.01)
6662+
model.fc1.bias.fill_(0.0)
6663+
6664+
# Fixed input: 28x28 values from -scale to +scale
6665+
x = torch.linspace(-scale, scale, 28 * 28).reshape(1, 1, 28, 28)
6666+
6667+
TorchBaseTest.run_compare_torch(
6668+
x,
6669+
model,
6670+
input_as_shape=False,
6671+
frontend=frontend,
6672+
backend=backend,
6673+
compute_unit=compute_unit,
6674+
)
6675+
66316676
@pytest.mark.parametrize(
66326677
"compute_unit, backend, frontend, shape",
66336678
itertools.product(compute_units, backends, frontends, COMMON_SHAPES_ALL),

0 commit comments

Comments
 (0)