Skip to content

Commit 4433d1f

Browse files
authored
[MLX] Add leaky_relu op handler (#20305)
Summary: - Add MLX lowering for aten.leaky_relu.default using existing GreaterEqual, Multiply, and Where nodes. - Add focused MLX op tests for custom negative_slope values, including a slope above 1. Test Plan: - python -m py_compile backends/mlx/ops.py backends/mlx/test/test_ops.py - git diff --check HEAD^..HEAD - PATH="$PWD/.venv-mlx/bin:$PATH" .venv-mlx/bin/lintrunner backends/mlx/ops.py backends/mlx/test/test_ops.py - .venv-mlx/bin/python -m executorch.backends.mlx.test.run_all_tests leaky_relu --timeout 180 cc @metascroy
1 parent 1b726b2 commit 4433d1f

2 files changed

Lines changed: 113 additions & 0 deletions

File tree

backends/mlx/ops.py

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -163,6 +163,8 @@
163163
from executorch.exir.dialects._ops import ops as exir_ops
164164
from torch.fx.node import Node
165165

166+
_LEAKY_RELU_DEFAULT_NEGATIVE_SLOPE = 0.01
167+
166168

167169
def require_static_int(value: Any, param_name: str, op_name: str) -> None:
168170
"""
@@ -2786,6 +2788,63 @@ def _relu_handler(P: MLXProgramBuilder, n: Node) -> Slot:
27862788
return out
27872789

27882790

2791+
@REGISTRY.register(target=[torch.ops.aten.leaky_relu.default])
2792+
def _leaky_relu_handler(P: MLXProgramBuilder, n: Node) -> Slot:
2793+
"""Handle aten.leaky_relu.default - leaky rectified linear unit.
2794+
2795+
leaky_relu(x) = x if x >= 0
2796+
= slope * x otherwise
2797+
2798+
Implemented as where(x >= 0, x, slope * x) so it stays correct for any
2799+
negative_slope (including values > 1), matching eager PyTorch.
2800+
"""
2801+
args = P.args(n)
2802+
require_args(args, 1, 2, "aten.leaky_relu")
2803+
require_kwargs(P.kwargs(n), set(), "aten.leaky_relu")
2804+
2805+
x = args[0]
2806+
negative_slope = _LEAKY_RELU_DEFAULT_NEGATIVE_SLOPE
2807+
if len(args) > 1 and args[1] is not None:
2808+
negative_slope = float(args[1])
2809+
2810+
x_meta = n.args[0].meta.get("val")
2811+
if x_meta is None:
2812+
raise ValueError("Input tensor metadata not found for leaky_relu")
2813+
dtype = x_meta.dtype
2814+
2815+
zero_slot = emit_lifted_constant(P, 0.0, dtype)
2816+
slope_slot = emit_lifted_constant(P, negative_slope, dtype)
2817+
2818+
_, cond_slot = P.make_tmp_slot()
2819+
P.emit(
2820+
GreaterEqualNode(
2821+
a=P.slot_to_tid(x),
2822+
b=P.slot_to_tid(zero_slot),
2823+
out=P.slot_to_tid(cond_slot),
2824+
)
2825+
)
2826+
2827+
_, scaled_slot = P.make_tmp_slot()
2828+
P.emit(
2829+
MultiplyNode(
2830+
a=P.slot_to_tid(slope_slot),
2831+
b=P.slot_to_tid(x),
2832+
out=P.slot_to_tid(scaled_slot),
2833+
)
2834+
)
2835+
2836+
out = P.make_or_get_slot(n)
2837+
P.emit(
2838+
WhereNode(
2839+
condition=P.slot_to_tid(cond_slot),
2840+
x=P.slot_to_tid(x),
2841+
y=P.slot_to_tid(scaled_slot),
2842+
out=P.slot_to_tid(out),
2843+
)
2844+
)
2845+
return out
2846+
2847+
27892848
@REGISTRY.register(target=[torch.ops.aten._log_softmax.default])
27902849
def _log_softmax_handler(P: MLXProgramBuilder, n: Node) -> Slot:
27912850
"""Handle aten._log_softmax.default - log of softmax.

backends/mlx/test/test_ops.py

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -405,6 +405,60 @@ def create_inputs(self) -> Tuple[torch.Tensor, ...]:
405405
return (x,)
406406

407407

408+
class LeakyReLUModel(nn.Module):
409+
"""Model that applies leaky_relu with an optional negative slope."""
410+
411+
def __init__(self, negative_slope: Optional[float] = 0.01):
412+
super().__init__()
413+
self.negative_slope = negative_slope
414+
415+
def forward(self, x: torch.Tensor) -> torch.Tensor:
416+
if self.negative_slope is None:
417+
return torch.nn.functional.leaky_relu(x)
418+
return torch.nn.functional.leaky_relu(x, negative_slope=self.negative_slope)
419+
420+
421+
@register_test
422+
class LeakyReLUTest(OpTestCase):
423+
"""Test case for leaky_relu activation with various negative slopes."""
424+
425+
name = "leaky_relu"
426+
rtol = 1e-5
427+
atol = 1e-5
428+
429+
def __init__(
430+
self,
431+
shape: Tuple[int, ...] = (2, 3, 4),
432+
negative_slope: Optional[float] = 0.01,
433+
):
434+
self.shape = shape
435+
self.negative_slope = negative_slope
436+
shape_str = "x".join(str(s) for s in shape)
437+
slope_str = "default" if negative_slope is None else f"slope{negative_slope}"
438+
self.name = f"leaky_relu_{slope_str}_{shape_str}"
439+
440+
@classmethod
441+
def get_test_configs(cls) -> List["LeakyReLUTest"]:
442+
return [
443+
cls(shape=(2, 3, 4), negative_slope=0.01),
444+
cls(shape=(2, 3, 4), negative_slope=None),
445+
cls(shape=(4, 8), negative_slope=0.1),
446+
cls(shape=(10,), negative_slope=0.2),
447+
cls(shape=(10,), negative_slope=1.5),
448+
cls(shape=(2, 8, 16), negative_slope=0.01),
449+
]
450+
451+
def create_model(self) -> nn.Module:
452+
return LeakyReLUModel(self.negative_slope)
453+
454+
def create_inputs(self) -> Tuple[torch.Tensor, ...]:
455+
numel = 1
456+
for size in self.shape:
457+
numel *= size
458+
x = torch.linspace(-4.0, 4.0, steps=numel).reshape(self.shape)
459+
return (x,)
460+
461+
408462
class GELUModel(nn.Module):
409463
"""Simple model using GELU activation."""
410464

0 commit comments

Comments
 (0)