diff --git a/backends/mlx/ops.py b/backends/mlx/ops.py
index 3f7da88a793..27d214e0ae9 100644
--- a/backends/mlx/ops.py
+++ b/backends/mlx/ops.py
@@ -50,6 +50,7 @@
     AsStridedNode,
     AsTypeNode,
     Atan2Node,
+    BitwiseInvertNode,
     BroadcastToNode,
     CeilNode,
     ClipNode,
@@ -3066,27 +3067,40 @@ def _where_handler(P: MLXProgramBuilder, n: Node) -> Slot:
 @REGISTRY.register(target=[torch.ops.aten.bitwise_not.default])
 def _bitwise_not_handler(P: MLXProgramBuilder, n: Node) -> Slot:
-    """Handle aten.bitwise_not - for boolean tensors, dispatch to logical_not."""
+    """Handle aten.bitwise_not - logical_not for bool, bitwise_invert for integers."""
     args = P.args(n)
     require_args(args, 1, 1, "aten.bitwise_not")
     require_kwargs(P.kwargs(n), set(), "aten.bitwise_not")
 
     x_meta = n.args[0].meta.get("val")
+    out = P.make_or_get_slot(n)
 
-    if x_meta is not None and x_meta.dtype == torch.bool:
-        # For boolean tensors, bitwise_not is equivalent to logical_not
-        out = P.make_or_get_slot(n)
+    if x_meta is None or not hasattr(x_meta, "dtype"):
+        raise NotImplementedError(
+            "aten.bitwise_not requires known input dtype metadata for MLX lowering"
+        )
+
+    if x_meta.dtype == torch.bool:
         P.emit(
             LogicalNotNode(
                 x=P.slot_to_tid(args[0]),
                 out=P.slot_to_tid(out),
             )
         )
-        return out
+    elif x_meta.dtype in {
+        torch.int32,
+        torch.int64,
+    }:
+        P.emit(
+            BitwiseInvertNode(
+                x=P.slot_to_tid(args[0]),
+                out=P.slot_to_tid(out),
+            )
+        )
     else:
         raise NotImplementedError(
-            f"aten.bitwise_not is only supported for boolean tensors. "
-            f"Got dtype={x_meta.dtype if x_meta else 'unknown'}"
+            f"aten.bitwise_not on dtype {x_meta.dtype} is not supported for MLX lowering"
         )
+    return out
 
 
 @REGISTRY.register(
diff --git a/backends/mlx/runtime/MLXInterpreter.h b/backends/mlx/runtime/MLXInterpreter.h
index 9fa08ab722d..304fdfe9805 100644
--- a/backends/mlx/runtime/MLXInterpreter.h
+++ b/backends/mlx/runtime/MLXInterpreter.h
@@ -1380,6 +1380,13 @@ inline void exec_logical_not(
   st.set_tensor(n.out, logical_not(st.const_tensor_ref(n.x), s));
 }
 
+inline void exec_bitwise_invert(
+    const BitwiseInvertNode& n,
+    ExecutionState& st,
+    StreamOrDevice s) {
+  st.set_tensor(n.out, bitwise_invert(st.const_tensor_ref(n.x), s));
+}
+
 inline void exec_logical_and(
     const LogicalAndNode& n,
     ExecutionState& st,
@@ -2028,6 +2035,10 @@ class Interpreter {
       case OpCode::LOGICAL_NOT:
        ops::exec_logical_not(std::get<LogicalNotNode>(instr.node), st, s);
         break;
+      case OpCode::BITWISE_INVERT:
+        ops::exec_bitwise_invert(
+            std::get<BitwiseInvertNode>(instr.node), st, s);
+        break;
       case OpCode::LOGICAL_AND:
         ops::exec_logical_and(std::get<LogicalAndNode>(instr.node), st, s);
         break;
diff --git a/backends/mlx/serialization/schema.fbs b/backends/mlx/serialization/schema.fbs
index 6e8d6f47db8..67b4636f0be 100644
--- a/backends/mlx/serialization/schema.fbs
+++ b/backends/mlx/serialization/schema.fbs
@@ -562,6 +562,11 @@ table LogicalNotNode {
   x: Tid (required);
   out: Tid (required);
 }
 
+table BitwiseInvertNode {
+  x: Tid (required);
+  out: Tid (required);
+}
+
 table LogicalAndNode {
   a: Tid (required);
   b: Tid (required);
@@ -1113,7 +1118,8 @@ union OpNode {
   GatherMmNode,
   GatherQmmNode,
   ScanNode,
-  MetalKernelNode
+  MetalKernelNode,
+  BitwiseInvertNode
   // BC: Add new op nodes here (append only)
 }
 
diff --git a/backends/mlx/test/test_ops.py b/backends/mlx/test/test_ops.py
index 7ba3902e436..459d5aa1e73 100644
--- a/backends/mlx/test/test_ops.py
+++ b/backends/mlx/test/test_ops.py
@@ -4111,6 +4111,7 @@ def create_model(self) -> nn.Module:
     {"op_name": "abs", "op_fn": torch.abs},
     {"op_name": "neg", "op_fn": torch.neg},
     {"op_name": "logical_not","op_fn": torch.logical_not, "shapes": [(2, 3, 4), (10,), (4, 8)], "dtypes": [torch.bool], "input_fn": _bool_input_fn()},
+    {"op_name": "bitwise_not_int", "op_fn": torch.bitwise_not, "shapes": _SHAPES_3, "dtypes": [torch.int32, torch.int64], "input_fn": _int_input_fn()},
     {"op_name": "isnan", "op_fn": torch.isnan, "shapes": _SHAPES_3, "dtypes": [torch.float32, torch.float16, torch.bfloat16], "input_fn": _nan_input_fn()},
     # activations
     {"op_name": "relu", "op_fn": torch.relu, "shapes": [(2, 3, 4), (10,), (4, 8), (2, 8, 16), (1, 128, 64)], "dtypes": [torch.float32], "input_fn": _input_fn(scale=2, offset=-1)},