Skip to content

Commit a74c68c

Browse files
authored
Add torch op coverage for LLM attention mask construction (apple#2668)
Adds the small set of torch ops that HuggingFace attention-mask code emits via torch.export but that coremltools didn't yet handle; the gaps were exposed while converting google/gemma-4-E2B-it: - Register bitwise_or / bitwise_xor (with `or` / `xor` aliases for the post-sanitize form). The existing bitwise_and was the only registered member of the family; this restores symmetry. Both new handlers reuse the logical_* MIL primitives, matching the existing bitwise_and pattern. - Relax bitwise_and / bitwise_or / bitwise_xor to accept mixed-bool inputs (cast both to bool when at least one is bool). Pure non-bool inputs are still rejected with the same error so genuine integer bitwise math is unchanged. This unblocks Gemma-style mask combination where a bool causal mask meets a float padding mask. - Register aten::new_ones mirroring the existing new_zeros, using _make_fill_op so float-typed shape inputs from torch.export are coerced to int32. - Add where.ScalarOther as an alias on the existing where handler (which already does dtype promotion and broadcasting). - Fix sanitize_op_kind so the `__name__` wrapper is also stripped after the namespace and overload suffix have been removed. Previously aten::__or__.Tensor sanitized to "__or__" instead of "or", making the registry lookup miss even when an "or" handler existed. Tests: - Unit tests for sanitize_op_kind covering the dunder-after-namespace case in test_internal_graph.py. - Op-level tests for new_ones, bitwise_or, bitwise_xor and the `tensor | tensor` operator form in test_torch_ops.py. Validated end-to-end on google/gemma-4-E2B-it: torch.export -> ct.convert -> mlprogram now succeeds and the fp32 model output matches the PyTorch reference (top-5 5/5, per-position argmax 100%, max abs diff 0.05).
1 parent 896bb1c commit a74c68c

4 files changed

Lines changed: 179 additions & 7 deletions

File tree

coremltools/converters/mil/frontend/torch/ops.py

Lines changed: 43 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -5734,19 +5734,45 @@ def bitwise_not(context, node):
57345734
context.add(x)
57355735

57365736

5737-
@register_torch_op(torch_alias=["and"])
5738-
def bitwise_and(context, node):
5737+
def _bitwise_as_logical_if_boolean(context, node, op_name, logical_handler):
5738+
"""Shared body for bitwise_and/or/xor.
5739+
5740+
Core ML has no true bitwise op on integers, so we lower to the logical
5741+
counterpart whenever at least one operand is bool — which covers the common
5742+
"combine boolean masks" pattern in attention/transformer code (where
5743+
torch.export may produce a mixed bool/float pair). Pure non-bool inputs are
5744+
still rejected so we don't silently change semantics for genuine integer
5745+
bitwise math.
5746+
"""
57395747
inputs = _get_inputs(context, node)
5740-
57415748
input_dtypes = [i.dtype for i in inputs]
5742-
if all(types.is_bool(input_dtype) for input_dtype in input_dtypes):
5743-
logical_and(context, node)
5749+
if any(types.is_bool(d) for d in input_dtypes):
5750+
logical_handler(context, node)
57445751
else:
57455752
raise NotImplementedError(
5746-
f"The `bitwise_and` op only supports boolean input, but get {input_dtypes}."
5753+
f"The `{op_name}` op only supports boolean input, but get {input_dtypes}."
57475754
)
57485755

57495756

5757+
@register_torch_op(torch_alias=["and"])
5758+
def bitwise_and(context, node):
5759+
_bitwise_as_logical_if_boolean(context, node, "bitwise_and", logical_and)
5760+
5761+
5762+
# "or" and "xor" cover the post-sanitize form of "aten::__or__" / "aten::__xor__"
5763+
# which torch.export emits for `tensor | tensor` / `tensor ^ tensor`. These are
5764+
# common when building boolean attention masks (e.g. Gemma combines a causal
5765+
# mask with a padding mask via __or__).
5766+
@register_torch_op(torch_alias=["or"])
5767+
def bitwise_or(context, node):
5768+
_bitwise_as_logical_if_boolean(context, node, "bitwise_or", logical_or)
5769+
5770+
5771+
@register_torch_op(torch_alias=["xor"])
5772+
def bitwise_xor(context, node):
5773+
_bitwise_as_logical_if_boolean(context, node, "bitwise_xor", logical_xor)
5774+
5775+
57505776
@register_torch_op
57515777
def logical_not(context, node):
57525778
# There is an optional `out` parameter in torch.logical_not.
@@ -6663,6 +6689,16 @@ def new_zeros(context, node):
66636689
context.add(mb.fill(shape=shape, value=0., name=node.name))
66646690

66656691

6692+
@register_torch_op
6693+
def new_ones(context, node):
6694+
# tensor.new_ones(size) — same shape semantics as new_zeros, value is 1.
6695+
# Use _make_fill_op so float-typed shape inputs (which torch.export sometimes
6696+
# produces) are coerced to int32 automatically.
6697+
inputs = _get_inputs(context, node)
6698+
result = _make_fill_op(inputs[1], 1.0, node.name)
6699+
context.add(result)
6700+
6701+
66666702
@register_torch_op
66676703
def scalar_tensor(context, node):
66686704
x = _get_inputs(context, node, expected=[1, 5])[0]
@@ -7443,7 +7479,7 @@ def _nonzero_as_tuple(context, node, x):
74437479
context.add(result, node.name)
74447480

74457481

7446-
@register_torch_op(torch_alias=["where.self"])
7482+
@register_torch_op(torch_alias=["where.self", "where.scalarother"])
74477483
def where(context, node):
74487484
inputs = _get_inputs(context, node)
74497485

coremltools/converters/mil/frontend/torch/test/test_internal_graph.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,39 @@
2121
from .. import utils
2222
from ..converter import TranscriptionContext
2323
from ..internal_graph import InternalTorchIRNode
24+
from ..utils import sanitize_op_kind
25+
26+
27+
class TestSanitizeOpKind:
28+
"""Unit tests for the op-name canonicalizer used by both TorchScript and EXIR
29+
frontends. The trickiest case is op overloads whose canonical name only
30+
contains a "__name__" wrapper after the namespace prefix is stripped — e.g.
31+
aten::__or__.Tensor must canonicalize to "or" so it resolves against the
32+
same registry entry as the legacy "__or__" form.
33+
"""
34+
35+
@pytest.mark.parametrize(
36+
"raw, expected",
37+
[
38+
# Already-canonical names round-trip.
39+
("add", "add"),
40+
("logical_or", "logical_or"),
41+
# Legacy double-underscore form (single token).
42+
("__add__", "add"),
43+
("__or__", "or"),
44+
# ATen / overload-suffixed forms.
45+
("aten::add.Tensor", "add"),
46+
("aten::bmm.default", "bmm"),
47+
("aten::pow.Tensor_Scalar", "pow"),
48+
# Dunder hidden behind a namespace + overload suffix — the case that
49+
# used to slip through and produce e.g. "__or__" as the lookup key.
50+
("aten::__or__.Tensor", "or"),
51+
("aten::__and__.Tensor", "and"),
52+
("aten::__xor__.Tensor", "xor"),
53+
],
54+
)
55+
def test_sanitize_op_kind(self, raw, expected):
56+
assert sanitize_op_kind(raw) == expected
2457

2558

2659
class TestTorchOps:

coremltools/converters/mil/frontend/torch/test/test_torch_ops.py

Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4501,6 +4501,34 @@ def forward(self, x):
45014501
)
45024502

45034503

4504+
class TestNewOnes(TorchBaseTest):
4505+
@pytest.mark.parametrize(
4506+
"compute_unit, backend, frontend, shape",
4507+
itertools.product(
4508+
compute_units,
4509+
backends,
4510+
frontends,
4511+
[
4512+
(1,),
4513+
(2, 3),
4514+
(1, 1, 2, 5, 1),
4515+
],
4516+
),
4517+
)
4518+
def test_new_ones_static(self, compute_unit, backend, frontend, shape):
4519+
class OnesStaticModel(nn.Module):
4520+
def forward(self, x):
4521+
return x.new_ones(x.shape)
4522+
4523+
self.run_compare_torch(
4524+
shape,
4525+
OnesStaticModel().eval(),
4526+
frontend=frontend,
4527+
backend=backend,
4528+
compute_unit=compute_unit,
4529+
)
4530+
4531+
45044532
class TestNewFull(TorchBaseTest):
45054533
@pytest.mark.parametrize(
45064534
"compute_unit, backend, frontend, rank",
@@ -13316,6 +13344,75 @@ def forward(self, x, y):
1331613344
)
1331713345

1331813346

13347+
class TestBitwiseOr(TorchBaseTest):
13348+
@pytest.mark.parametrize(
13349+
"compute_unit, backend, frontend",
13350+
itertools.product(compute_units, backends, frontends),
13351+
)
13352+
def test_bitwise_or(self, compute_unit, backend, frontend):
13353+
class TestModel(torch.nn.Module):
13354+
def forward(self, x, y):
13355+
return torch.bitwise_or(x, y)
13356+
13357+
input_shape = (2, 3)
13358+
input_data_x = torch.rand(*input_shape) > 0.2
13359+
input_data_y = torch.rand(*input_shape) < 0.8
13360+
self.run_compare_torch(
13361+
[input_data_x, input_data_y],
13362+
TestModel(),
13363+
frontend=frontend,
13364+
backend=backend,
13365+
compute_unit=compute_unit,
13366+
input_as_shape=False,
13367+
)
13368+
13369+
@pytest.mark.parametrize(
13370+
"compute_unit, backend, frontend",
13371+
itertools.product(compute_units, backends, frontends),
13372+
)
13373+
def test_or_operator(self, compute_unit, backend, frontend):
13374+
# Exercises tensor.__or__ (i.e. `x | y`) which sanitizes to "or" and
13375+
# is the form torch.export emits when building boolean attention masks.
13376+
class TestModel(torch.nn.Module):
13377+
def forward(self, x, y):
13378+
return x | y
13379+
13380+
input_shape = (2, 3)
13381+
input_data_x = torch.rand(*input_shape) > 0.2
13382+
input_data_y = torch.rand(*input_shape) < 0.8
13383+
self.run_compare_torch(
13384+
[input_data_x, input_data_y],
13385+
TestModel(),
13386+
frontend=frontend,
13387+
backend=backend,
13388+
compute_unit=compute_unit,
13389+
input_as_shape=False,
13390+
)
13391+
13392+
13393+
class TestBitwiseXor(TorchBaseTest):
13394+
@pytest.mark.parametrize(
13395+
"compute_unit, backend, frontend",
13396+
itertools.product(compute_units, backends, frontends),
13397+
)
13398+
def test_bitwise_xor(self, compute_unit, backend, frontend):
13399+
class TestModel(torch.nn.Module):
13400+
def forward(self, x, y):
13401+
return torch.bitwise_xor(x, y)
13402+
13403+
input_shape = (2, 3)
13404+
input_data_x = torch.rand(*input_shape) > 0.2
13405+
input_data_y = torch.rand(*input_shape) < 0.8
13406+
self.run_compare_torch(
13407+
[input_data_x, input_data_y],
13408+
TestModel(),
13409+
frontend=frontend,
13410+
backend=backend,
13411+
compute_unit=compute_unit,
13412+
input_as_shape=False,
13413+
)
13414+
13415+
1331913416
class TestUnfold(TorchBaseTest):
1332013417
@pytest.mark.parametrize(
1332113418
"compute_unit, backend, frontend, input_shape, is_dynamic_hw, kernel_size, dilation, padding, stride",

coremltools/converters/mil/frontend/torch/utils.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -163,6 +163,12 @@ def skip_default_prefix_and_suffix_with_deliminator(
163163
op_kind = skip_default_prefix_and_suffix_with_deliminator(op_kind, "::")
164164
op_kind = skip_default_prefix_and_suffix_with_deliminator(op_kind, ".")
165165

166+
# 4. Strip the "__name__" wrapper again. The dunder may only become visible
167+
# after stripping the namespace and overload suffix above, e.g.
168+
# "aten::__or__.Tensor" -> "__or__" -> "or".
169+
if op_kind.startswith("__") and op_kind.endswith("__"):
170+
op_kind = op_kind[2:-2]
171+
166172
return op_kind
167173

168174

0 commit comments

Comments
 (0)