Skip to content

Commit 810bc80

Browse files
author
RoomWithOutRoof
committed
Add MLX Op Handler for aten.isinf
Implement isinf op handler for the MLX delegate backend. isinf(x) is decomposed as abs(x) == inf, using existing AbsNode and EqualNode which are already supported in the MLX graph schema. The handler also includes a corresponding test case with _inf_input_fn that generates inputs containing both positive and negative infinity. Fixes: #18922
1 parent 490ec5c commit 810bc80

2 files changed

Lines changed: 51 additions & 1 deletion

File tree

backends/mlx/ops.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -444,6 +444,41 @@ def _isnan_handler(P: MLXProgramBuilder, n: Node) -> Slot:
444444
return out
445445

446446

447+
@REGISTRY.register(target=[torch.ops.aten.isinf.default])
448+
def _isinf_handler(P: MLXProgramBuilder, n: Node) -> Slot:
449+
"""Handle aten.isinf - check for infinite values element-wise.
450+
451+
isinf(x) is equivalent to abs(x) == inf.
452+
"""
453+
args = P.args(n)
454+
require_args(args, 1, 1, "aten.isinf")
455+
require_kwargs(P.kwargs(n), set(), "aten.isinf")
456+
x = args[0]
457+
458+
# Create abs(x)
459+
_, abs_tmp = P.make_tmp_slot()
460+
P.emit(
461+
AbsNode(
462+
x=P.slot_to_tid(x),
463+
out=P.slot_to_tid(abs_tmp),
464+
)
465+
)
466+
467+
# Create inf constant
468+
inf_slot = emit_lifted_constant(P, float('inf'), torch.float32)
469+
470+
# Compare abs(x) == inf
471+
out = P.make_or_get_slot(n)
472+
P.emit(
473+
EqualNode(
474+
a=P.slot_to_tid(abs_tmp),
475+
b=P.slot_to_tid(inf_slot),
476+
out=P.slot_to_tid(out),
477+
)
478+
)
479+
return out
480+
481+
447482
_BINARY_OPS: List[Tuple[List[Any], Any, str, bool]] = [
448483
(
449484
[torch.ops.aten.mul.Tensor, torch.ops.aten.mul.Scalar],

backends/mlx/test/test_ops.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4020,6 +4020,21 @@ def fn(shape, dtype):
40204020
return fn
40214021

40224022

4023+
def _inf_input_fn():
4024+
"""Return a callable(shape, dtype) that generates inputs with some inf values."""
4025+
4026+
def fn(shape, dtype):
4027+
x = torch.randn(shape, dtype=dtype)
4028+
# Insert some inf values
4029+
mask_pos = torch.rand(shape) > 0.8
4030+
mask_neg = torch.rand(shape) > 0.9
4031+
x[mask_pos] = float('inf')
4032+
x[mask_neg] = float('-inf')
4033+
return (x,)
4034+
4035+
return fn
4036+
4037+
40234038
# Standard shape and dtype configs used by unary tests.
40244039
_SHAPES_3 = [(16,), (4, 4), (2, 3, 4)]
40254040
_SHAPES_2 = [(16,), (4, 4)]
@@ -4112,7 +4127,7 @@ def create_model(self) -> nn.Module:
41124127
{"op_name": "neg", "op_fn": torch.neg},
41134128
{"op_name": "logical_not","op_fn": torch.logical_not, "shapes": [(2, 3, 4), (10,), (4, 8)], "dtypes": [torch.bool], "input_fn": _bool_input_fn()},
41144129
{"op_name": "isnan", "op_fn": torch.isnan, "shapes": _SHAPES_3, "dtypes": [torch.float32, torch.float16, torch.bfloat16], "input_fn": _nan_input_fn()},
4115-
# activations
4130+
{"op_name": "isinf", "op_fn": torch.isinf, "shapes": _SHAPES_3, "dtypes": [torch.float32, torch.float16, torch.bfloat16], "input_fn": _inf_input_fn()}, # activations
41164131
{"op_name": "relu", "op_fn": torch.relu, "shapes": [(2, 3, 4), (10,), (4, 8), (2, 8, 16), (1, 128, 64)], "dtypes": [torch.float32], "input_fn": _input_fn(scale=2, offset=-1)},
41174132
{"op_name": "sigmoid", "op_fn": torch.sigmoid, "shapes": [(2, 3, 4), (10,), (4, 8), (2, 8, 16), (1, 1, 128)], "dtypes": [torch.float32], "input_fn": _input_fn(scale=2)},
41184133
{"op_name": "tanh", "op_fn": torch.tanh, "shapes": [(2, 3, 4), (10,), (4, 8), (2, 8, 16), (1, 1, 128)], "dtypes": [torch.float32], "input_fn": _input_fn(scale=3)},

0 commit comments

Comments
 (0)