Skip to content

Commit 441eca1

Browse files
committed
Add MLX op handler for aten.isinf
Add decomposed handler for aten.isinf using AbsNode + EqualNode with an inf constant. This enables isinf to run on Metal GPU via MLX instead of falling back to CPU execution. Also adds isinf to the unary op test suite with standard test inputs.
1 parent 9ca0ff1 commit 441eca1

2 files changed

Lines changed: 41 additions & 0 deletions

File tree

backends/mlx/ops.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -397,6 +397,46 @@ def normalize_reduction_dim(
397397
]
398398

399399

400+
@REGISTRY.register(target=[torch.ops.aten.isinf.default])
401+
def _isinf_handler(P: MLXProgramBuilder, n: Node) -> Slot:
402+
"""Handle aten.isinf.default - check for infinite values element-wise.
403+
404+
isinf(x) is equivalent to abs(x) == inf.
405+
This decomposition avoids a dedicated isinf MLX op by using AbsNode + EqualNode.
406+
"""
407+
args = P.args(n)
408+
require_args(args, 1, 1, "aten.isinf")
409+
require_kwargs(P.kwargs(n), set(), "aten.isinf")
410+
x = args[0]
411+
412+
# Get input dtype for the constant
413+
x_meta = n.args[0].meta.get("val")
414+
dtype = x_meta.dtype if x_meta is not None else torch.float32
415+
416+
# Create abs(x) using a temporary slot
417+
_, abs_tmp = P.make_tmp_slot()
418+
P.emit(
419+
AbsNode(
420+
x=P.slot_to_tid(x),
421+
out=P.slot_to_tid(abs_tmp),
422+
)
423+
)
424+
425+
# Create inf constant
426+
inf_slot = emit_lifted_constant(P, float("inf"), dtype)
427+
428+
# Compare abs(x) == inf
429+
out = P.make_or_get_slot(n)
430+
P.emit(
431+
EqualNode(
432+
a=P.slot_to_tid(abs_tmp),
433+
b=P.slot_to_tid(inf_slot),
434+
out=P.slot_to_tid(out),
435+
)
436+
)
437+
return out
438+
439+
400440
def _make_unary_handler(node_cls: Any, op_name: str):
401441
"""Create a handler for a simple unary op: x → node_cls(x, out)."""
402442

backends/mlx/test/test_ops.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4095,6 +4095,7 @@ def create_model(self) -> nn.Module:
40954095
{"op_name": "abs", "op_fn": torch.abs},
40964096
{"op_name": "neg", "op_fn": torch.neg},
40974097
{"op_name": "logical_not","op_fn": torch.logical_not, "shapes": [(2, 3, 4), (10,), (4, 8)], "dtypes": [torch.bool], "input_fn": _bool_input_fn()},
4098+
{"op_name": "isinf", "op_fn": torch.isinf, "shapes": [(2, 3, 4), (10,), (4, 8)], "dtypes": [torch.float32], "input_fn": _input_fn(uniform=True)}, # isinf uses custom inputs via model wrapper
40984099
# activations
40994100
{"op_name": "relu", "op_fn": torch.relu, "shapes": [(2, 3, 4), (10,), (4, 8), (2, 8, 16), (1, 128, 64)], "dtypes": [torch.float32], "input_fn": _input_fn(scale=2, offset=-1)},
41004101
{"op_name": "sigmoid", "op_fn": torch.sigmoid, "shapes": [(2, 3, 4), (10,), (4, 8), (2, 8, 16), (1, 1, 128)], "dtypes": [torch.float32], "input_fn": _input_fn(scale=2)},

0 commit comments

Comments
 (0)