Skip to content

Commit b55b95c

Browse files
author
RoomWithOutRoof
committed
Add MLX op handler for aten.isnan
Add a handler for aten.isnan in the MLX delegate using the mathematical property that NaN != NaN (NaN is the only value not equal to itself). This uses the existing NotEqualNode by comparing x with itself. Also add corresponding test with a custom input function that injects NaN values at a configurable fraction.
1 parent a489707 commit b55b95c

2 files changed

Lines changed: 43 additions & 0 deletions

File tree

backends/mlx/ops.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -418,6 +418,32 @@ def handler(P: MLXProgramBuilder, n: Node) -> Slot:
418418
REGISTRY.register(target=[_target])(_make_unary_handler(_node_cls, _op_name))
419419

420420

421+
# ---------------------------------------------------------------------------
422+
# Numerical checks
423+
# ---------------------------------------------------------------------------
424+
425+
426+
@REGISTRY.register(target=[torch.ops.aten.isnan.default])
def _isnan_handler(P: MLXProgramBuilder, n: Node) -> Slot:
    """Lower ``aten.isnan`` to an element-wise self-inequality test.

    IEEE-754 NaN is the only value that compares unequal to itself, so
    ``x != x`` is true exactly at the NaN positions of ``x``.  This lets
    us reuse the existing NotEqual kernel instead of adding a new op.
    """
    operands = P.args(n)
    require_args(operands, 1, 1, "aten.isnan")
    require_kwargs(P.kwargs(n), set(), "aten.isnan")

    src = operands[0]
    result = P.make_or_get_slot(n)
    # Compare the input tensor with itself; NotEqual yields the NaN mask.
    compare = NotEqualNode(
        a=P.slot_to_tid(src),
        b=P.slot_to_tid(src),
        out=P.slot_to_tid(result),
    )
    P.emit(compare)
    return result
446+
421447
_BINARY_OPS: List[Tuple[List[Any], Any, str, bool]] = [
422448
(
423449
[torch.ops.aten.mul.Tensor, torch.ops.aten.mul.Scalar],

backends/mlx/test/test_ops.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4004,6 +4004,22 @@ def fn(shape, dtype):
40044004
return fn
40054005

40064006

4007+
def _nan_input_fn(nan_frac: float = 0.3):
4008+
"""Return a callable(shape, dtype) that generates inputs with some NaN values.
4009+
4010+
Args:
4011+
nan_frac: Fraction of elements to set to NaN (default 0.3 = 30%).
4012+
"""
4013+
4014+
def fn(shape, dtype):
4015+
x = torch.randn(shape, dtype=dtype)
4016+
mask = torch.rand(shape) > (1.0 - nan_frac)
4017+
x[mask] = float("nan")
4018+
return (x,)
4019+
4020+
return fn
4021+
4022+
40074023
# Standard shape and dtype configs used by unary tests.
40084024
_SHAPES_3 = [(16,), (4, 4), (2, 3, 4)]
40094025
_SHAPES_2 = [(16,), (4, 4)]
@@ -4095,6 +4111,7 @@ def create_model(self) -> nn.Module:
40954111
{"op_name": "abs", "op_fn": torch.abs},
40964112
{"op_name": "neg", "op_fn": torch.neg},
40974113
{"op_name": "logical_not","op_fn": torch.logical_not, "shapes": [(2, 3, 4), (10,), (4, 8)], "dtypes": [torch.bool], "input_fn": _bool_input_fn()},
4114+
{"op_name": "isnan", "op_fn": torch.isnan, "shapes": _SHAPES_3, "dtypes": [torch.float32, torch.float16, torch.bfloat16], "input_fn": _nan_input_fn()},
40984115
# activations
40994116
{"op_name": "relu", "op_fn": torch.relu, "shapes": [(2, 3, 4), (10,), (4, 8), (2, 8, 16), (1, 128, 64)], "dtypes": [torch.float32], "input_fn": _input_fn(scale=2, offset=-1)},
41004117
{"op_name": "sigmoid", "op_fn": torch.sigmoid, "shapes": [(2, 3, 4), (10,), (4, 8), (2, 8, 16), (1, 1, 128)], "dtypes": [torch.float32], "input_fn": _input_fn(scale=2)},

0 commit comments

Comments
 (0)