Commit a74c68c
authored
Add torch op coverage for LLM attention mask construction (apple#2668)
Adds the small set of torch ops that HuggingFace attention-mask code
emits via torch.export but coremltools didn't yet handle, exposed
while converting google/gemma-4-E2B-it:
- Register bitwise_or / bitwise_xor (with `or` / `xor` aliases for the
post-sanitize form). The existing bitwise_and was the only registered
member of the family; this restores symmetry. Both new handlers reuse
the logical_* MIL primitives, matching the existing bitwise_and pattern.
- Relax bitwise_and / bitwise_or / bitwise_xor to accept mixed-bool
inputs (cast both to bool when at least one is bool). Pure non-bool
inputs are still rejected with the same error so genuine integer
bitwise math is unchanged. This unblocks Gemma-style mask combination
where a bool causal mask meets a float padding mask.
- Register aten::new_ones mirroring the existing new_zeros, using
_make_fill_op so float-typed shape inputs from torch.export are
coerced to int32.
- Add where.ScalarOther as an alias on the existing where handler
(which already does dtype promotion and broadcasting).
- Fix sanitize_op_kind so the `__name__` wrapper is also stripped after
the namespace and overload suffix have been removed. Previously
aten::__or__.Tensor sanitized to "__or__" instead of "or", making the
registry lookup miss even when an "or" handler existed.
Tests:
- Unit tests for sanitize_op_kind covering the dunder-after-namespace
case in test_internal_graph.py.
- Op-level tests for new_ones, bitwise_or, bitwise_xor and the
`tensor | tensor` operator form in test_torch_ops.py.
Validated end-to-end on google/gemma-4-E2B-it: torch.export ->
ct.convert -> mlprogram now succeeds and the fp32 model output
matches the PyTorch reference (top-5 5/5, per-position argmax 100%,
max abs diff 0.05).1 parent 896bb1c commit a74c68c
4 files changed
Lines changed: 179 additions & 7 deletions
File tree
- coremltools/converters/mil/frontend/torch
- test
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
5734 | 5734 | | |
5735 | 5735 | | |
5736 | 5736 | | |
5737 | | - | |
5738 | | - | |
| 5737 | + | |
| 5738 | + | |
| 5739 | + | |
| 5740 | + | |
| 5741 | + | |
| 5742 | + | |
| 5743 | + | |
| 5744 | + | |
| 5745 | + | |
| 5746 | + | |
5739 | 5747 | | |
5740 | | - | |
5741 | 5748 | | |
5742 | | - | |
5743 | | - | |
| 5749 | + | |
| 5750 | + | |
5744 | 5751 | | |
5745 | 5752 | | |
5746 | | - | |
| 5753 | + | |
5747 | 5754 | | |
5748 | 5755 | | |
5749 | 5756 | | |
| 5757 | + | |
| 5758 | + | |
| 5759 | + | |
| 5760 | + | |
| 5761 | + | |
| 5762 | + | |
| 5763 | + | |
| 5764 | + | |
| 5765 | + | |
| 5766 | + | |
| 5767 | + | |
| 5768 | + | |
| 5769 | + | |
| 5770 | + | |
| 5771 | + | |
| 5772 | + | |
| 5773 | + | |
| 5774 | + | |
| 5775 | + | |
5750 | 5776 | | |
5751 | 5777 | | |
5752 | 5778 | | |
| |||
6663 | 6689 | | |
6664 | 6690 | | |
6665 | 6691 | | |
| 6692 | + | |
| 6693 | + | |
| 6694 | + | |
| 6695 | + | |
| 6696 | + | |
| 6697 | + | |
| 6698 | + | |
| 6699 | + | |
| 6700 | + | |
| 6701 | + | |
6666 | 6702 | | |
6667 | 6703 | | |
6668 | 6704 | | |
| |||
7443 | 7479 | | |
7444 | 7480 | | |
7445 | 7481 | | |
7446 | | - | |
| 7482 | + | |
7447 | 7483 | | |
7448 | 7484 | | |
7449 | 7485 | | |
| |||
Lines changed: 33 additions & 0 deletions
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
21 | 21 | | |
22 | 22 | | |
23 | 23 | | |
| 24 | + | |
| 25 | + | |
| 26 | + | |
| 27 | + | |
| 28 | + | |
| 29 | + | |
| 30 | + | |
| 31 | + | |
| 32 | + | |
| 33 | + | |
| 34 | + | |
| 35 | + | |
| 36 | + | |
| 37 | + | |
| 38 | + | |
| 39 | + | |
| 40 | + | |
| 41 | + | |
| 42 | + | |
| 43 | + | |
| 44 | + | |
| 45 | + | |
| 46 | + | |
| 47 | + | |
| 48 | + | |
| 49 | + | |
| 50 | + | |
| 51 | + | |
| 52 | + | |
| 53 | + | |
| 54 | + | |
| 55 | + | |
| 56 | + | |
24 | 57 | | |
25 | 58 | | |
26 | 59 | | |
| |||
Lines changed: 97 additions & 0 deletions
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
4501 | 4501 | | |
4502 | 4502 | | |
4503 | 4503 | | |
| 4504 | + | |
| 4505 | + | |
| 4506 | + | |
| 4507 | + | |
| 4508 | + | |
| 4509 | + | |
| 4510 | + | |
| 4511 | + | |
| 4512 | + | |
| 4513 | + | |
| 4514 | + | |
| 4515 | + | |
| 4516 | + | |
| 4517 | + | |
| 4518 | + | |
| 4519 | + | |
| 4520 | + | |
| 4521 | + | |
| 4522 | + | |
| 4523 | + | |
| 4524 | + | |
| 4525 | + | |
| 4526 | + | |
| 4527 | + | |
| 4528 | + | |
| 4529 | + | |
| 4530 | + | |
| 4531 | + | |
4504 | 4532 | | |
4505 | 4533 | | |
4506 | 4534 | | |
| |||
13316 | 13344 | | |
13317 | 13345 | | |
13318 | 13346 | | |
| 13347 | + | |
| 13348 | + | |
| 13349 | + | |
| 13350 | + | |
| 13351 | + | |
| 13352 | + | |
| 13353 | + | |
| 13354 | + | |
| 13355 | + | |
| 13356 | + | |
| 13357 | + | |
| 13358 | + | |
| 13359 | + | |
| 13360 | + | |
| 13361 | + | |
| 13362 | + | |
| 13363 | + | |
| 13364 | + | |
| 13365 | + | |
| 13366 | + | |
| 13367 | + | |
| 13368 | + | |
| 13369 | + | |
| 13370 | + | |
| 13371 | + | |
| 13372 | + | |
| 13373 | + | |
| 13374 | + | |
| 13375 | + | |
| 13376 | + | |
| 13377 | + | |
| 13378 | + | |
| 13379 | + | |
| 13380 | + | |
| 13381 | + | |
| 13382 | + | |
| 13383 | + | |
| 13384 | + | |
| 13385 | + | |
| 13386 | + | |
| 13387 | + | |
| 13388 | + | |
| 13389 | + | |
| 13390 | + | |
| 13391 | + | |
| 13392 | + | |
| 13393 | + | |
| 13394 | + | |
| 13395 | + | |
| 13396 | + | |
| 13397 | + | |
| 13398 | + | |
| 13399 | + | |
| 13400 | + | |
| 13401 | + | |
| 13402 | + | |
| 13403 | + | |
| 13404 | + | |
| 13405 | + | |
| 13406 | + | |
| 13407 | + | |
| 13408 | + | |
| 13409 | + | |
| 13410 | + | |
| 13411 | + | |
| 13412 | + | |
| 13413 | + | |
| 13414 | + | |
| 13415 | + | |
13319 | 13416 | | |
13320 | 13417 | | |
13321 | 13418 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
163 | 163 | | |
164 | 164 | | |
165 | 165 | | |
| 166 | + | |
| 167 | + | |
| 168 | + | |
| 169 | + | |
| 170 | + | |
| 171 | + | |
166 | 172 | | |
167 | 173 | | |
168 | 174 | | |
| |||
0 commit comments