Skip to content

Commit 8949a40

Browse files
nanookclawclaude
andcommitted
Add MLX op handler for aten.bitwise_xor (#18927)
Add BitwiseXorNode to the MLX delegate, enabling element-wise bitwise XOR for boolean and integer tensors via mlx::core::bitwise_xor. Closes #18927 Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 2f339f0 commit 8949a40

10 files changed

Lines changed: 9425 additions & 1 deletion

File tree

backends/mlx/_generated_inspector.py

Lines changed: 929 additions & 0 deletions
Large diffs are not rendered by default.

backends/mlx/ops.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@
5050
AsStridedNode,
5151
AsTypeNode,
5252
Atan2Node,
53+
BitwiseXorNode,
5354
BroadcastToNode,
5455
CeilNode,
5556
ClipNode,
@@ -490,6 +491,12 @@ def handler(P: MLXProgramBuilder, n: Node) -> Slot:
490491
"aten.ne",
491492
True,
492493
),
494+
(
495+
[torch.ops.aten.bitwise_xor.Tensor, torch.ops.aten.bitwise_xor.Scalar],
496+
BitwiseXorNode,
497+
"aten.bitwise_xor",
498+
True,
499+
),
493500
]
494501

495502

backends/mlx/runtime/MLXInterpreter.h

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1395,6 +1395,13 @@ exec_logical_or(const LogicalOrNode& n, ExecutionState& st, StreamOrDevice s) {
13951395
n.out, logical_or(st.const_tensor_ref(n.a), st.const_tensor_ref(n.b), s));
13961396
}
13971397

1398+
inline void exec_bitwise_xor(
1399+
const BitwiseXorNode& n, ExecutionState& st, StreamOrDevice s) {
1400+
st.set_tensor(
1401+
n.out,
1402+
bitwise_xor(st.const_tensor_ref(n.a), st.const_tensor_ref(n.b), s));
1403+
}
1404+
13981405
inline void exec_tri(const TriNode& n, ExecutionState& st, StreamOrDevice s) {
13991406
int rows = resolve_int(n.n, st);
14001407
int cols = resolve_int(n.m, st);
@@ -2227,6 +2234,10 @@ class Interpreter {
22272234
case OpCode::METAL_KERNEL:
22282235
ops::exec_metal_kernel(std::get<MetalKernelNode>(instr.node), st, s);
22292236
break;
2237+
case OpCode::BITWISE_XOR:
2238+
ops::exec_bitwise_xor(
2239+
std::get<BitwiseXorNode>(instr.node), st, s);
2240+
break;
22302241
default:
22312242
throw std::runtime_error(
22322243
"Unknown opcode: " + std::to_string(static_cast<int>(instr.op)));

0 commit comments

Comments
 (0)