99
1010import torch
1111from executorch .backends .arm ._passes import ArmPass
12+ from executorch .backends .arm ._passes .convert_full_like_to_full_pass import (
13+ ConvertFullLikeToFullPass ,
14+ )
1215from executorch .exir .dialects ._ops import ops as exir_ops
1316from executorch .exir .pass_base import ExportPass
1417
@@ -20,14 +23,14 @@ def _get_leaky_relu_ops(op) -> tuple:
2023 if op in edge_ops :
2124 return (
2225 exir_ops .edge .aten .clamp .default ,
23- exir_ops .edge .aten .full .default ,
26+ exir_ops .edge .aten .full_like .default ,
2427 exir_ops .edge .aten .mul .Tensor ,
2528 exir_ops .edge .aten .add .Tensor ,
2629 )
2730 elif op in torch_ops :
2831 return (
2932 torch .ops .aten .clamp .default ,
30- torch .ops .aten .full .default ,
33+ torch .ops .aten .full_like .default ,
3134 torch .ops .aten .mul .Tensor ,
3235 torch .ops .aten .add .Tensor ,
3336 )
@@ -42,33 +45,31 @@ class DecomposeLeakyReLUPass(ArmPass):
4245 Example:
4346 %op1 = clamp(x,0,None) (equivalent to max(0,x))
4447 %op2 = clamp(x,None,0) (equivalent to min(0,x))
45- %op3 = full(x.shape ,slope)
48+ %op3 = full_like(x ,slope)
4649 %op4 = mul(%op3,%op2)
4750 %op5 = add(%op1,%op4)
4851
4952 """
5053
51- _passes_required_after : Set [Type [ExportPass ]] = set ()
54+ _passes_required_after : Set [Type [ExportPass ]] = { ConvertFullLikeToFullPass }
5255
5356 def call_operator (self , op , args , kwargs , meta ):
5457 if op not in (edge_ops + torch_ops ) or not self .allowed_to_transform (meta ):
5558 return super ().call_operator (op , args , kwargs , meta )
5659
5760 x = args [0 ]
5861 slope = args [1 ] if len (args ) > 1 else 0.01
59- dtype = x .node .meta ["val" ].dtype
60- device = x .node .meta ["val" ].device
61- clamp , full , mul , add = _get_leaky_relu_ops (op )
62+ clamp , full_like , mul , add = _get_leaky_relu_ops (op )
6263 op1 = super ().call_operator (
6364 op = clamp , args = (x , 0 , None ), kwargs = kwargs , meta = meta
6465 )
6566 op2 = super ().call_operator (
6667 op = clamp , args = (x , None , 0 ), kwargs = kwargs , meta = meta
6768 )
6869 op3 = super ().call_operator (
69- op = full ,
70- args = (x . node . meta [ "val" ]. shape , slope ),
71- kwargs = {"dtype" : dtype , "device" : device },
70+ op = full_like ,
71+ args = (x , slope ),
72+ kwargs = {},
7273 meta = meta ,
7374 )
7475 op4 = super ().call_operator (op = mul , args = (op3 , op2 ), kwargs = kwargs , meta = meta )
0 commit comments