Skip to content

Commit e95804f

Browse files
authored
Fix int32 torch.mm runtime by lowering to matmul (#2673)
* Fix int32 torch.mm runtime by lowering to matmul
* Restore constant-weight lowering for non-int32 matmul
1 parent fcb59d5 commit e95804f

2 files changed

Lines changed: 54 additions & 1 deletion

File tree

coremltools/converters/mil/frontend/torch/ops.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1036,7 +1036,12 @@ def pixel_unshuffle(context, node):
10361036

10371037

10381038
def _construct_matmul(x: Var, y: Var, name: Optional[str] = None) -> Var:
1039-
if (len(y.shape) == 2 and len(x.shape) <= 3) and (_is_const(y) or y.is_descendant_of_const):
1039+
if (
1040+
x.dtype != types.int32
1041+
and y.dtype != types.int32
1042+
and (len(y.shape) == 2 and len(x.shape) <= 3)
1043+
and (_is_const(y) or y.is_descendant_of_const)
1044+
):
10401045
linear_x, weight = x, y
10411046
transposed_weight = mb.transpose(x=weight, perm=(1, 0))
10421047
res = mb.linear(x=linear_x, weight=transposed_weight, name=name)

coremltools/converters/mil/frontend/torch/test/test_torch_ops.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
import coremltools as ct
2020
from coremltools import RangeDim, Shape, TensorType
2121
from coremltools._deps import _HAS_TORCH_AUDIO, _HAS_TORCH_VISION, version_lt
22+
from coremltools.converters.mil.backend.mil.load import BlobWriter
2223
from coremltools.converters.mil import testing_reqs
2324
from coremltools.converters.mil.frontend.torch.utils import (
2425
NUM_TO_TORCH_DTYPE,
@@ -41,6 +42,7 @@
4142
ModuleWrapper,
4243
TorchBaseTest,
4344
contains_op,
45+
convert_to_mlmodel,
4446
export_torch_model_to_frontend,
4547
frontends,
4648
generate_input_data,
@@ -7110,6 +7112,52 @@ def test_bmm(self, compute_unit, backend, frontend):
71107112
[shape_x, shape_y], model, compute_unit=compute_unit, backend=backend, frontend=frontend
71117113
)
71127114

7115+
@pytest.mark.parametrize("frontend", frontends)
7116+
@pytest.mark.parametrize(
7117+
"convert_to, minimum_deployment_target",
7118+
[("neuralnetwork", None), ("mlprogram", ct.target.iOS15)],
7119+
)
7120+
def test_mm_with_int32_constant_weight(
7121+
self, frontend, convert_to, minimum_deployment_target
7122+
):
7123+
class TestModel(torch.nn.Module):
7124+
def __init__(self):
7125+
super().__init__()
7126+
self.register_buffer(
7127+
"weight", torch.randint(low=-9, high=9, size=(4, 4), dtype=torch.int32)
7128+
)
7129+
7130+
def forward(self, x):
7131+
return torch.mm(x, self.weight)
7132+
7133+
model = TestModel().eval()
7134+
input_data = torch.randint(low=-9, high=9, size=(4, 4), dtype=torch.int32)
7135+
model_spec = export_torch_model_to_frontend(model, input_data, frontend)
7136+
7137+
if convert_to == "mlprogram" and BlobWriter is None:
7138+
pytest.skip("BlobWriter not loaded")
7139+
7140+
mlmodel = convert_to_mlmodel(
7141+
model_spec,
7142+
[input_data],
7143+
backend=(convert_to, "fp32"),
7144+
converter_input_type=[
7145+
ct.TensorType(name="x", shape=input_data.shape, dtype=np.int32)
7146+
],
7147+
compute_unit=ct.ComputeUnit.CPU_ONLY,
7148+
minimum_deployment_target=minimum_deployment_target,
7149+
)
7150+
7151+
ops = get_op_types_in_program(mlmodel._mil_program)
7152+
assert "matmul" in ops
7153+
assert "linear" not in ops
7154+
7155+
if ct.utils._is_macos() and ct.models.model._MLModelProxy is not None:
7156+
output_name = mlmodel._spec.description.output[0].name
7157+
expected = model(input_data).numpy()
7158+
prediction = mlmodel.predict({"x": input_data.numpy()})[output_name]
7159+
np.testing.assert_array_equal(prediction.astype(np.int32), expected)
7160+
71137161
@pytest.mark.parametrize(
71147162
"compute_unit, backend, frontend",
71157163
itertools.product(

0 commit comments

Comments (0)