
Commit 107eb92

Authored by Jah-yee, RoomWithOutRoof, and metascroy
Add MLX op handler for aten.isnan (#18952)
Good day,

## Summary

This PR adds an MLX op handler for `aten.isnan` to the PyTorch ExecuTorch MLX delegate, as requested in issue #18920.

## Changes

### Handler (`backends/mlx/ops.py`)

Added `_isnan_handler`, registered for `torch.ops.aten.isnan.default`. The implementation uses the mathematical property that `NaN != NaN` (NaN is the only floating-point value that is not equal to itself) to decompose the operation as:

```
isnan(x) = x != x
```

This reuses the existing `NotEqualNode` by comparing the input with itself, avoiding the need for a dedicated MLX isnan primitive.

### Test (`backends/mlx/test/test_ops.py`)

- Added a `_nan_input_fn(nan_frac)` helper that generates tensors with a configurable fraction of NaN values.
- Added an `isnan` entry to `_UNARY_OP_TESTS` covering `_SHAPES_3` and three float dtypes (`float32`, `float16`, `bfloat16`).

## Testing

The test can be run with:

```bash
python -m executorch.backends.mlx.test.run_all_tests -k isnan
```

The implementation is consistent with the approach described in the linked issue and follows the existing code patterns in the MLX backend.

Thank you for your attention. If there are any issues or suggestions, please leave a comment and I will address them promptly.

Warmly,
RoomWithOutRoof

Co-authored-by: RoomWithOutRoof <roomwithoutroof@users.noreply.github.com>
Co-authored-by: Scott Roy <161522778+metascroy@users.noreply.github.com>
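The decomposition the handler relies on can be checked directly in PyTorch, independently of the MLX backend (a standalone sketch; the example values are illustrative):

```python
import torch

# NaN is the only floating-point value not equal to itself, so
# isnan(x) can be decomposed element-wise as x != x. Note that
# infinities compare equal to themselves, so they are not flagged.
x = torch.tensor([1.0, float("nan"), -3.5, float("inf"), float("nan")])

decomposed = x != x          # the decomposition used by the handler
reference = torch.isnan(x)   # PyTorch's built-in check

assert torch.equal(decomposed, reference)

# The equivalence also holds for the dtypes covered by the new test entry.
for dtype in (torch.float32, torch.float16, torch.bfloat16):
    y = x.to(dtype)
    assert torch.equal(y != y, torch.isnan(y))
```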
1 parent 9576316 commit 107eb92

2 files changed

Lines changed: 43 additions & 0 deletions

File tree

backends/mlx/ops.py

Lines changed: 26 additions & 0 deletions
```diff
@@ -418,6 +418,32 @@ def handler(P: MLXProgramBuilder, n: Node) -> Slot:
     REGISTRY.register(target=[_target])(_make_unary_handler(_node_cls, _op_name))


+# ---------------------------------------------------------------------------
+# Numerical checks
+# ---------------------------------------------------------------------------
+
+
+@REGISTRY.register(target=[torch.ops.aten.isnan.default])
+def _isnan_handler(P: MLXProgramBuilder, n: Node) -> Slot:
+    """Handle aten.isnan - check for NaN values element-wise.
+
+    isnan(x) is equivalent to x != x (NaN is the only value not equal to itself).
+    """
+    args = P.args(n)
+    require_args(args, 1, 1, "aten.isnan")
+    require_kwargs(P.kwargs(n), set(), "aten.isnan")
+    x = args[0]
+    out = P.make_or_get_slot(n)
+    P.emit(
+        NotEqualNode(
+            a=P.slot_to_tid(x),
+            b=P.slot_to_tid(x),
+            out=P.slot_to_tid(out),
+        )
+    )
+    return out
+
+
 _BINARY_OPS: List[Tuple[List[Any], Any, str, bool]] = [
     (
         [torch.ops.aten.mul.Tensor, torch.ops.aten.mul.Scalar],
```

backends/mlx/test/test_ops.py

Lines changed: 17 additions & 0 deletions
```diff
@@ -4004,6 +4004,22 @@ def fn(shape, dtype):
     return fn


+def _nan_input_fn(nan_frac: float = 0.3):
+    """Return a callable(shape, dtype) that generates inputs with some NaN values.
+
+    Args:
+        nan_frac: Fraction of elements to set to NaN (default 0.3 = 30%).
+    """
+
+    def fn(shape, dtype):
+        x = torch.randn(shape, dtype=dtype)
+        mask = torch.rand(shape) > (1.0 - nan_frac)
+        x[mask] = float("nan")
+        return (x,)
+
+    return fn
+
+
 # Standard shape and dtype configs used by unary tests.
 _SHAPES_3 = [(16,), (4, 4), (2, 3, 4)]
 _SHAPES_2 = [(16,), (4, 4)]
@@ -4095,6 +4111,7 @@ def create_model(self) -> nn.Module:
     {"op_name": "abs", "op_fn": torch.abs},
     {"op_name": "neg", "op_fn": torch.neg},
     {"op_name": "logical_not","op_fn": torch.logical_not, "shapes": [(2, 3, 4), (10,), (4, 8)], "dtypes": [torch.bool], "input_fn": _bool_input_fn()},
+    {"op_name": "isnan", "op_fn": torch.isnan, "shapes": _SHAPES_3, "dtypes": [torch.float32, torch.float16, torch.bfloat16], "input_fn": _nan_input_fn()},
     # activations
     {"op_name": "relu", "op_fn": torch.relu, "shapes": [(2, 3, 4), (10,), (4, 8), (2, 8, 16), (1, 128, 64)], "dtypes": [torch.float32], "input_fn": _input_fn(scale=2, offset=-1)},
     {"op_name": "sigmoid", "op_fn": torch.sigmoid, "shapes": [(2, 3, 4), (10,), (4, 8), (2, 8, 16), (1, 1, 128)], "dtypes": [torch.float32], "input_fn": _input_fn(scale=2)},
```
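As a sanity check on the test helper, the snippet below reproduces `_nan_input_fn` from the diff and verifies that the generated tensor contains roughly the requested fraction of NaNs (a standalone sketch; only `torch` is required, and the shape and seed are illustrative):

```python
import torch


def _nan_input_fn(nan_frac: float = 0.3):
    """Reproduced from the diff: returns a callable(shape, dtype) that
    generates inputs with roughly `nan_frac` of the elements set to NaN."""

    def fn(shape, dtype):
        x = torch.randn(shape, dtype=dtype)
        # Each element is independently replaced with NaN with
        # probability nan_frac.
        mask = torch.rand(shape) > (1.0 - nan_frac)
        x[mask] = float("nan")
        return (x,)

    return fn


torch.manual_seed(0)
(x,) = _nan_input_fn(nan_frac=0.3)((64, 64), torch.float32)
frac = torch.isnan(x).float().mean().item()
# For 4096 elements the observed fraction concentrates tightly around 0.3.
assert 0.2 < frac < 0.4
```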

0 commit comments