Skip to content

Commit b7d9cae

Browse files
MLX delegate: add integer support for aten.bitwise_not
1 parent e6efe18 commit b7d9cae

4 files changed

Lines changed: 28 additions & 8 deletions

File tree

backends/mlx/ops.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@
5050
AsStridedNode,
5151
AsTypeNode,
5252
Atan2Node,
53+
BitwiseInvertNode,
5354
BroadcastToNode,
5455
CeilNode,
5556
ClipNode,
@@ -3066,27 +3067,28 @@ def _where_handler(P: MLXProgramBuilder, n: Node) -> Slot:
30663067

30673068
@REGISTRY.register(target=[torch.ops.aten.bitwise_not.default])
def _bitwise_not_handler(P: MLXProgramBuilder, n: Node) -> Slot:
    """Lower aten.bitwise_not to the MLX delegate program.

    Boolean tensors lower to logical_not; all other (integer) dtypes lower
    to bitwise_invert, matching PyTorch's bitwise_not semantics.

    Args:
        P: Program builder collecting emitted MLX nodes.
        n: FX node for the aten.bitwise_not call (exactly one positional
           arg, no kwargs).

    Returns:
        The output slot allocated for this node's result.
    """
    args = P.args(n)
    require_args(args, 1, 1, "aten.bitwise_not")
    require_kwargs(P.kwargs(n), set(), "aten.bitwise_not")

    meta = n.args[0].meta.get("val")
    out = P.make_or_get_slot(n)

    # Both node types share the same (x, out) field layout, so pick the
    # class by dtype and emit once instead of duplicating the branches.
    node_cls = (
        LogicalNotNode
        if meta is not None and meta.dtype == torch.bool
        else BitwiseInvertNode
    )
    P.emit(node_cls(x=P.slot_to_tid(args[0]), out=P.slot_to_tid(out)))
    return out
30903092

30913093

30923094
@REGISTRY.register(

backends/mlx/runtime/MLXInterpreter.h

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1380,6 +1380,13 @@ inline void exec_logical_not(
13801380
st.set_tensor(n.out, logical_not(st.const_tensor_ref(n.x), s));
13811381
}
13821382

1383+
// Execute a BitwiseInvertNode: elementwise bitwise NOT (~x) of the input
// tensor identified by n.x, storing the result under tensor id n.out on
// the given stream/device.
inline void exec_bitwise_invert(
    const BitwiseInvertNode& n,
    ExecutionState& st,
    StreamOrDevice s) {
  st.set_tensor(n.out, bitwise_invert(st.const_tensor_ref(n.x), s));
}
1389+
13831390
inline void exec_logical_and(
13841391
const LogicalAndNode& n,
13851392
ExecutionState& st,
@@ -2028,6 +2035,10 @@ class Interpreter {
20282035
case OpCode::LOGICAL_NOT:
20292036
ops::exec_logical_not(std::get<LogicalNotNode>(instr.node), st, s);
20302037
break;
2038+
case OpCode::BITWISE_INVERT:
2039+
ops::exec_bitwise_invert(
2040+
std::get<BitwiseInvertNode>(instr.node), st, s);
2041+
break;
20312042
case OpCode::LOGICAL_AND:
20322043
ops::exec_logical_and(std::get<LogicalAndNode>(instr.node), st, s);
20332044
break;

backends/mlx/serialization/schema.fbs

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -562,6 +562,11 @@ table LogicalNotNode {
562562
out: Tid (required);
563563
}
564564

565+
// Elementwise bitwise NOT (~x) of an integer tensor.
// x: input tensor id; out: output tensor id.
table BitwiseInvertNode {
  x: Tid (required);
  out: Tid (required);
}
569+
565570
table LogicalAndNode {
566571
a: Tid (required);
567572
b: Tid (required);
@@ -1113,7 +1118,8 @@ union OpNode {
11131118
GatherMmNode,
11141119
GatherQmmNode,
11151120
ScanNode,
1116-
MetalKernelNode
1121+
MetalKernelNode,
1122+
BitwiseInvertNode
11171123
// BC: Add new op nodes here (append only)
11181124
}
11191125

backends/mlx/test/test_ops.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4111,6 +4111,7 @@ def create_model(self) -> nn.Module:
41114111
{"op_name": "abs", "op_fn": torch.abs},
41124112
{"op_name": "neg", "op_fn": torch.neg},
41134113
{"op_name": "logical_not","op_fn": torch.logical_not, "shapes": [(2, 3, 4), (10,), (4, 8)], "dtypes": [torch.bool], "input_fn": _bool_input_fn()},
4114+
{"op_name": "bitwise_not_int", "op_fn": torch.bitwise_not, "shapes": _SHAPES_3, "dtypes": [torch.int32, torch.int64], "input_fn": _int_input_fn()},
41144115
{"op_name": "isnan", "op_fn": torch.isnan, "shapes": _SHAPES_3, "dtypes": [torch.float32, torch.float16, torch.bfloat16], "input_fn": _nan_input_fn()},
41154116
# activations
41164117
{"op_name": "relu", "op_fn": torch.relu, "shapes": [(2, 3, 4), (10,), (4, 8), (2, 8, 16), (1, 128, 64)], "dtypes": [torch.float32], "input_fn": _input_fn(scale=2, offset=-1)},

0 commit comments

Comments
 (0)