@@ -367,9 +367,7 @@ def get_paddle_nodes(self, args, kwargs):
367367class ChangePrefixMatcher (BaseMatcher ):
368368 def get_paddle_api (self ):
369369 if self .transformer .mode == "min" :
370- return self .origin_torch_api
371- if self .paddle_api :
372- return self .paddle_api
370+ return self .origin_attr
373371
374372 torch_package = self .torch_api .split ("." , maxsplit = 1 )[0 ]
375373 assert (
@@ -394,7 +392,11 @@ def get_paddle_class_attribute_nodes(self, node):
394392 return "unchange"
395393
396394 def get_paddle_class_nodes (self , func , args , kwargs ):
397- self .parse_func (func )
395+ if self .transformer .mode == "min" :
396+ self .paddle_api = astor .to_source (func ).strip ("\n " )
397+ else :
398+ self .parse_func (func )
399+
398400 args = self .parse_args (args )
399401 kwargs = self .parse_kwargs (kwargs , allow_none = True )
400402
@@ -957,7 +959,7 @@ def get_paddle_nodes(self, args, kwargs):
957959 if len (args ) > 1 or (len (args ) == 1 and isinstance (args [0 ], ast .Constant )):
958960 shape = self .parse_args (args )
959961 elif isinstance (args [0 ], ast .Starred ):
960- shape = astor .to_source (args [0 ].value ).replace ("\n " , " " )
962+ shape = astor .to_source (args [0 ].value ).strip ("\n " )
961963 else :
962964 shape = self .parse_args (args )[0 ]
963965 kwargs = {"shape" : str (shape ).replace ("'" , "" ), ** kwargs }
@@ -1015,7 +1017,7 @@ def get_paddle_nodes(self, args, kwargs):
10151017 if len (args ) > 1 or (len (args ) == 1 and isinstance (args [0 ], ast .Constant )):
10161018 shape = self .parse_args (args )
10171019 elif isinstance (args [0 ], ast .Starred ):
1018- shape = astor .to_source (args [0 ].value ).replace ("\n " , " " )
1020+ shape = astor .to_source (args [0 ].value ).strip ("\n " )
10191021 else :
10201022 shape = self .parse_args (args )[0 ]
10211023
@@ -1233,9 +1235,9 @@ def get_paddle_nodes(self, args, kwargs):
12331235
12341236 self .enable_utils_code ()
12351237 if paddle_api == "paddle.min" :
1236- self .set_paddle_api ( "paddle_min" )
1238+ self .paddle_api = "paddle_min"
12371239 elif paddle_api == "paddle.max" :
1238- self .set_paddle_api ( "paddle_max" )
1240+ self .paddle_api = "paddle_max"
12391241
12401242 return ChangeAPIMatcher .get_paddle_nodes (self , args , kwargs )
12411243
@@ -1518,7 +1520,7 @@ def get_paddle_nodes(self, args, kwargs):
15181520 if len (args ) > 1 or (len (args ) == 1 and isinstance (args [0 ], ast .Constant )):
15191521 shape = self .parse_args (args )
15201522 elif len (args ) == 1 and isinstance (args [0 ], ast .Starred ):
1521- shape = astor .to_source (args [0 ].value ).replace ("\n " , " " )
1523+ shape = astor .to_source (args [0 ].value ).strip ("\n " )
15221524 else :
15231525 if len (args ) == 0 :
15241526 data = []
@@ -1797,7 +1799,7 @@ def get_paddle_class_nodes(self, func, args, kwargs):
17971799 if len (args ) > 1 or (len (args ) == 1 and isinstance (args [0 ], ast .Constant )):
17981800 perm = self .parse_args (args )
17991801 elif isinstance (args [0 ], ast .Starred ):
1800- perm = astor .to_source (args [0 ].value ).replace ("\n " , " " )
1802+ perm = astor .to_source (args [0 ].value ).strip ("\n " )
18011803 else :
18021804 perm = self .parse_args (args )[0 ]
18031805
@@ -1819,7 +1821,7 @@ def get_paddle_class_nodes(self, func, args, kwargs):
18191821 if len (args ) > 1 or (len (args ) == 1 and isinstance (args [0 ], ast .Constant )):
18201822 shape = self .parse_args (args )
18211823 elif isinstance (args [0 ], ast .Starred ):
1822- shape = astor .to_source (args [0 ].value ).replace ("\n " , " " )
1824+ shape = astor .to_source (args [0 ].value ).strip ("\n " )
18231825 else :
18241826 shape = self .parse_args (args )[0 ]
18251827
@@ -1971,7 +1973,7 @@ def paddle_split(x, num_or_sections, axis=0):
19711973
19721974 def generate_code (self , kwargs ):
19731975 self .enable_utils_code ()
1974- self .set_paddle_api ( "paddle_split" )
1976+ self .paddle_api = "paddle_split"
19751977 return GenericMatcher .generate_code (self , kwargs )
19761978
19771979
@@ -2041,7 +2043,7 @@ def get_paddle_class_nodes(self, func, args, kwargs):
20412043 if len (args ) > 1 or (len (args ) == 1 and isinstance (args [0 ], ast .Constant )):
20422044 shape = self .parse_args (args )
20432045 elif isinstance (args [0 ], ast .Starred ):
2044- shape = astor .to_source (args [0 ].value ).replace ("\n " , " " )
2046+ shape = astor .to_source (args [0 ].value ).strip ("\n " )
20452047 else :
20462048 shape = self .parse_args (args )[0 ]
20472049
@@ -2548,7 +2550,7 @@ def get_paddle_nodes(self, args, kwargs):
25482550 if len (args ) == 0 :
25492551 code = "()"
25502552 else :
2551- code = "tuple({})" .format (astor .to_source (args [0 ]).replace ("\n " , " " ))
2553+ code = "tuple({})" .format (astor .to_source (args [0 ]).strip ("\n " ))
25522554
25532555 return ast .parse (code ).body
25542556
@@ -4886,7 +4888,7 @@ def __init__(self, *args, **kwargs):
48864888
48874889 def generate_code (self , kwargs ):
48884890 self .enable_utils_code ()
4889- self .set_paddle_api ( "Embedding" )
4891+ self .paddle_api = "Embedding"
48904892 return GenericMatcher .generate_code (self , kwargs )
48914893
48924894
@@ -4949,7 +4951,7 @@ def get_paddle_nodes(self, args, kwargs):
49494951 x = self .parse_args (args )
49504952 else :
49514953 if isinstance (args [0 ], ast .Starred ):
4952- x = astor .to_source (args [0 ].value ).replace ("\n " , " " )
4954+ x = astor .to_source (args [0 ].value ).strip ("\n " )
49534955 else :
49544956 x = self .parse_args (args )
49554957 kwargs = {dest_var_arg_name : str (x ).replace ("'" , "" )}
@@ -4982,7 +4984,7 @@ def get_paddle_nodes(self, args, kwargs):
49824984 if len (args ) > 1 or (len (args ) == 1 and isinstance (args [0 ], ast .Constant )):
49834985 dest_var_arg_value = self .parse_args (args )
49844986 elif len (args ) == 1 and isinstance (args [0 ], ast .Starred ):
4985- dest_var_arg_value = astor .to_source (args [0 ].value ).replace ("\n " , " " )
4987+ dest_var_arg_value = astor .to_source (args [0 ].value ).strip ("\n " )
49864988 else :
49874989 dest_var_arg_value = self .parse_args (args )[0 ]
49884990
@@ -5011,7 +5013,7 @@ def get_paddle_class_nodes(self, func, args, kwargs):
50115013 if len (args ) > 1 or (len (args ) == 1 and isinstance (args [0 ], ast .Constant )):
50125014 dest_var_arg_value = self .parse_args (args )
50135015 elif len (args ) == 1 and isinstance (args [0 ], ast .Starred ):
5014- dest_var_arg_value = astor .to_source (args [0 ].value ).replace ("\n " , " " )
5016+ dest_var_arg_value = astor .to_source (args [0 ].value ).strip ("\n " )
50155017 else :
50165018 dest_var_arg_value = self .parse_args (args )[0 ]
50175019
@@ -5498,7 +5500,7 @@ def parse_args_and_kwargs(
54985500
54995501 for i , node in enumerate (args ):
55005502 k = args_list [i ]
5501- v = astor .to_source (node ).replace ("\n " , " " )
5503+ v = astor .to_source (node ).strip ("\n " )
55025504 new_kwargs [k ] = v
55035505
55045506 for node in kwargs :
@@ -5508,7 +5510,7 @@ def parse_args_and_kwargs(
55085510 return None
55095511 continue
55105512
5511- v = astor .to_source (node .value ).replace ("\n " , " " )
5513+ v = astor .to_source (node .value ).strip ("\n " )
55125514 new_kwargs [k ] = v
55135515
55145516 return new_kwargs
@@ -5536,7 +5538,7 @@ def forward(self, inputs, states = None):
55365538
55375539 def generate_code (self , kwargs ):
55385540 self .enable_utils_code ()
5539- self .set_paddle_api ( "GRUCell" )
5541+ self .paddle_api = "GRUCell"
55405542 return GenericMatcher .generate_code (self , kwargs )
55415543
55425544
@@ -5553,7 +5555,7 @@ def forward(self, inputs, states = None):
55535555
55545556 def generate_code (self , kwargs ):
55555557 self .enable_utils_code ()
5556- self .set_paddle_api ( "LSTMCell" )
5558+ self .paddle_api = "LSTMCell"
55575559 return GenericMatcher .generate_code (self , kwargs )
55585560
55595561
@@ -5570,7 +5572,7 @@ def forward(self, inputs, states = None):
55705572
55715573 def generate_code (self , kwargs ):
55725574 self .enable_utils_code ()
5573- self .set_paddle_api ( "SimpleRNNCell" )
5575+ self .paddle_api = "SimpleRNNCell"
55745576 return GenericMatcher .generate_code (self , kwargs )
55755577
55765578
@@ -5981,7 +5983,7 @@ def parse_args_and_kwargs(
59815983
59825984 for i , node in enumerate (args ):
59835985 k = args_list [i ]
5984- v = astor .to_source (node ).replace ("\n " , " " )
5986+ v = astor .to_source (node ).strip ("\n " )
59855987 new_kwargs [k ] = v
59865988
59875989 for node in kwargs :
@@ -5990,6 +5992,6 @@ def parse_args_and_kwargs(
59905992 if not allow_none :
59915993 return None
59925994 continue
5993- v = astor .to_source (node .value ).replace ("\n " , " " )
5995+ v = astor .to_source (node .value ).strip ("\n " )
59945996 new_kwargs [k ] = v
59955997 return new_kwargs
0 commit comments