Skip to content

Commit 2189c57

Browse files
authored
add scan unalign api (#868)
1 parent 5752bb8 commit 2189c57

7 files changed

Lines changed: 173 additions & 111 deletions

File tree

paconvert/api_matcher.py

Lines changed: 27 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -367,9 +367,7 @@ def get_paddle_nodes(self, args, kwargs):
367367
class ChangePrefixMatcher(BaseMatcher):
368368
def get_paddle_api(self):
369369
if self.transformer.mode == "min":
370-
return self.origin_torch_api
371-
if self.paddle_api:
372-
return self.paddle_api
370+
return self.origin_attr
373371

374372
torch_package = self.torch_api.split(".", maxsplit=1)[0]
375373
assert (
@@ -394,7 +392,11 @@ def get_paddle_class_attribute_nodes(self, node):
394392
return "unchange"
395393

396394
def get_paddle_class_nodes(self, func, args, kwargs):
397-
self.parse_func(func)
395+
if self.transformer.mode == "min":
396+
self.paddle_api = astor.to_source(func).strip("\n")
397+
else:
398+
self.parse_func(func)
399+
398400
args = self.parse_args(args)
399401
kwargs = self.parse_kwargs(kwargs, allow_none=True)
400402

@@ -957,7 +959,7 @@ def get_paddle_nodes(self, args, kwargs):
957959
if len(args) > 1 or (len(args) == 1 and isinstance(args[0], ast.Constant)):
958960
shape = self.parse_args(args)
959961
elif isinstance(args[0], ast.Starred):
960-
shape = astor.to_source(args[0].value).replace("\n", "")
962+
shape = astor.to_source(args[0].value).strip("\n")
961963
else:
962964
shape = self.parse_args(args)[0]
963965
kwargs = {"shape": str(shape).replace("'", ""), **kwargs}
@@ -1015,7 +1017,7 @@ def get_paddle_nodes(self, args, kwargs):
10151017
if len(args) > 1 or (len(args) == 1 and isinstance(args[0], ast.Constant)):
10161018
shape = self.parse_args(args)
10171019
elif isinstance(args[0], ast.Starred):
1018-
shape = astor.to_source(args[0].value).replace("\n", "")
1020+
shape = astor.to_source(args[0].value).strip("\n")
10191021
else:
10201022
shape = self.parse_args(args)[0]
10211023

@@ -1233,9 +1235,9 @@ def get_paddle_nodes(self, args, kwargs):
12331235

12341236
self.enable_utils_code()
12351237
if paddle_api == "paddle.min":
1236-
self.set_paddle_api("paddle_min")
1238+
self.paddle_api = "paddle_min"
12371239
elif paddle_api == "paddle.max":
1238-
self.set_paddle_api("paddle_max")
1240+
self.paddle_api = "paddle_max"
12391241

12401242
return ChangeAPIMatcher.get_paddle_nodes(self, args, kwargs)
12411243

@@ -1518,7 +1520,7 @@ def get_paddle_nodes(self, args, kwargs):
15181520
if len(args) > 1 or (len(args) == 1 and isinstance(args[0], ast.Constant)):
15191521
shape = self.parse_args(args)
15201522
elif len(args) == 1 and isinstance(args[0], ast.Starred):
1521-
shape = astor.to_source(args[0].value).replace("\n", "")
1523+
shape = astor.to_source(args[0].value).strip("\n")
15221524
else:
15231525
if len(args) == 0:
15241526
data = []
@@ -1797,7 +1799,7 @@ def get_paddle_class_nodes(self, func, args, kwargs):
17971799
if len(args) > 1 or (len(args) == 1 and isinstance(args[0], ast.Constant)):
17981800
perm = self.parse_args(args)
17991801
elif isinstance(args[0], ast.Starred):
1800-
perm = astor.to_source(args[0].value).replace("\n", "")
1802+
perm = astor.to_source(args[0].value).strip("\n")
18011803
else:
18021804
perm = self.parse_args(args)[0]
18031805

@@ -1819,7 +1821,7 @@ def get_paddle_class_nodes(self, func, args, kwargs):
18191821
if len(args) > 1 or (len(args) == 1 and isinstance(args[0], ast.Constant)):
18201822
shape = self.parse_args(args)
18211823
elif isinstance(args[0], ast.Starred):
1822-
shape = astor.to_source(args[0].value).replace("\n", "")
1824+
shape = astor.to_source(args[0].value).strip("\n")
18231825
else:
18241826
shape = self.parse_args(args)[0]
18251827

@@ -1971,7 +1973,7 @@ def paddle_split(x, num_or_sections, axis=0):
19711973

19721974
def generate_code(self, kwargs):
19731975
self.enable_utils_code()
1974-
self.set_paddle_api("paddle_split")
1976+
self.paddle_api = "paddle_split"
19751977
return GenericMatcher.generate_code(self, kwargs)
19761978

19771979

@@ -2041,7 +2043,7 @@ def get_paddle_class_nodes(self, func, args, kwargs):
20412043
if len(args) > 1 or (len(args) == 1 and isinstance(args[0], ast.Constant)):
20422044
shape = self.parse_args(args)
20432045
elif isinstance(args[0], ast.Starred):
2044-
shape = astor.to_source(args[0].value).replace("\n", "")
2046+
shape = astor.to_source(args[0].value).strip("\n")
20452047
else:
20462048
shape = self.parse_args(args)[0]
20472049

@@ -2548,7 +2550,7 @@ def get_paddle_nodes(self, args, kwargs):
25482550
if len(args) == 0:
25492551
code = "()"
25502552
else:
2551-
code = "tuple({})".format(astor.to_source(args[0]).replace("\n", ""))
2553+
code = "tuple({})".format(astor.to_source(args[0]).strip("\n"))
25522554

25532555
return ast.parse(code).body
25542556

@@ -4886,7 +4888,7 @@ def __init__(self, *args, **kwargs):
48864888

48874889
def generate_code(self, kwargs):
48884890
self.enable_utils_code()
4889-
self.set_paddle_api("Embedding")
4891+
self.paddle_api = "Embedding"
48904892
return GenericMatcher.generate_code(self, kwargs)
48914893

48924894

@@ -4949,7 +4951,7 @@ def get_paddle_nodes(self, args, kwargs):
49494951
x = self.parse_args(args)
49504952
else:
49514953
if isinstance(args[0], ast.Starred):
4952-
x = astor.to_source(args[0].value).replace("\n", "")
4954+
x = astor.to_source(args[0].value).strip("\n")
49534955
else:
49544956
x = self.parse_args(args)
49554957
kwargs = {dest_var_arg_name: str(x).replace("'", "")}
@@ -4982,7 +4984,7 @@ def get_paddle_nodes(self, args, kwargs):
49824984
if len(args) > 1 or (len(args) == 1 and isinstance(args[0], ast.Constant)):
49834985
dest_var_arg_value = self.parse_args(args)
49844986
elif len(args) == 1 and isinstance(args[0], ast.Starred):
4985-
dest_var_arg_value = astor.to_source(args[0].value).replace("\n", "")
4987+
dest_var_arg_value = astor.to_source(args[0].value).strip("\n")
49864988
else:
49874989
dest_var_arg_value = self.parse_args(args)[0]
49884990

@@ -5011,7 +5013,7 @@ def get_paddle_class_nodes(self, func, args, kwargs):
50115013
if len(args) > 1 or (len(args) == 1 and isinstance(args[0], ast.Constant)):
50125014
dest_var_arg_value = self.parse_args(args)
50135015
elif len(args) == 1 and isinstance(args[0], ast.Starred):
5014-
dest_var_arg_value = astor.to_source(args[0].value).replace("\n", "")
5016+
dest_var_arg_value = astor.to_source(args[0].value).strip("\n")
50155017
else:
50165018
dest_var_arg_value = self.parse_args(args)[0]
50175019

@@ -5498,7 +5500,7 @@ def parse_args_and_kwargs(
54985500

54995501
for i, node in enumerate(args):
55005502
k = args_list[i]
5501-
v = astor.to_source(node).replace("\n", "")
5503+
v = astor.to_source(node).strip("\n")
55025504
new_kwargs[k] = v
55035505

55045506
for node in kwargs:
@@ -5508,7 +5510,7 @@ def parse_args_and_kwargs(
55085510
return None
55095511
continue
55105512

5511-
v = astor.to_source(node.value).replace("\n", "")
5513+
v = astor.to_source(node.value).strip("\n")
55125514
new_kwargs[k] = v
55135515

55145516
return new_kwargs
@@ -5536,7 +5538,7 @@ def forward(self, inputs, states = None):
55365538

55375539
def generate_code(self, kwargs):
55385540
self.enable_utils_code()
5539-
self.set_paddle_api("GRUCell")
5541+
self.paddle_api = "GRUCell"
55405542
return GenericMatcher.generate_code(self, kwargs)
55415543

55425544

@@ -5553,7 +5555,7 @@ def forward(self, inputs, states = None):
55535555

55545556
def generate_code(self, kwargs):
55555557
self.enable_utils_code()
5556-
self.set_paddle_api("LSTMCell")
5558+
self.paddle_api = "LSTMCell"
55575559
return GenericMatcher.generate_code(self, kwargs)
55585560

55595561

@@ -5570,7 +5572,7 @@ def forward(self, inputs, states = None):
55705572

55715573
def generate_code(self, kwargs):
55725574
self.enable_utils_code()
5573-
self.set_paddle_api("SimpleRNNCell")
5575+
self.paddle_api = "SimpleRNNCell"
55745576
return GenericMatcher.generate_code(self, kwargs)
55755577

55765578

@@ -5981,7 +5983,7 @@ def parse_args_and_kwargs(
59815983

59825984
for i, node in enumerate(args):
59835985
k = args_list[i]
5984-
v = astor.to_source(node).replace("\n", "")
5986+
v = astor.to_source(node).strip("\n")
59855987
new_kwargs[k] = v
59865988

59875989
for node in kwargs:
@@ -5990,6 +5992,6 @@ def parse_args_and_kwargs(
59905992
if not allow_none:
59915993
return None
59925994
continue
5993-
v = astor.to_source(node.value).replace("\n", "")
5995+
v = astor.to_source(node.value).strip("\n")
59945996
new_kwargs[k] = v
59955997
return new_kwargs

paconvert/base.py

Lines changed: 34 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -175,15 +175,27 @@ def get_full_attr(self, node):
175175
else:
176176
return "None"
177177

178-
def get_full_attr_for_apiname(self, node):
178+
def get_torch_api_from_node(self, node):
179+
full_attr = self.get_full_attr(node)
180+
attr_list = full_attr.split(".")
181+
old_module = attr_list[0]
182+
if old_module in self.imports_map[self.file]:
183+
new_module = self.imports_map[self.file][old_module]
184+
attr_list[0] = new_module
185+
torch_api = ".".join(attr_list)
186+
return torch_api
187+
else:
188+
return None
189+
190+
def get_full_attr_strict(self, node):
179191
if len(self.imports_map[self.file]["torch_packages"]) == 0:
180192
return "NonTorchClass"
181193
# x.abs() -> 'abs'
182194
if isinstance(node, ast.Attribute):
183195
for item in self.black_list:
184196
if item == node.attr:
185197
return "NonTorchClass"
186-
return self.get_full_attr_for_apiname(node.value) + "." + node.attr
198+
return self.get_full_attr_strict(node.value) + "." + node.attr
187199
# x.abs() -> 'x'
188200
elif isinstance(node, ast.Name):
189201
for item in self.black_list:
@@ -202,7 +214,7 @@ def get_full_attr_for_apiname(self, node):
202214
node,
203215
(ast.Call, ast.Compare, ast.BinOp, ast.UnaryOp, ast.Subscript, ast.Assert),
204216
):
205-
node_str = astor.to_source(node).replace("\n", "")
217+
node_str = astor.to_source(node).strip("\n")
206218
for item in self.black_list:
207219
# (array(1.) + array(2.)).abs() ...
208220
if re.match(".*[^A-Za-z_]{1}%s\(" % item, node_str):
@@ -219,17 +231,16 @@ def get_full_attr_for_apiname(self, node):
219231
else:
220232
return "NonTorchClass"
221233

222-
def get_full_api_from_node(self, node):
223-
full_attr = self.get_full_attr_for_apiname(node)
234+
def replace_torch_module(self, origin_attr):
235+
full_attr = origin_attr
236+
224237
attr_list = full_attr.split(".")
225238
old_module = attr_list[0]
226239
if old_module in self.imports_map[self.file]:
227240
new_module = self.imports_map[self.file][old_module]
228241
attr_list[0] = new_module
229-
torch_api = ".".join(attr_list)
230-
return torch_api, full_attr
231-
else:
232-
return full_attr, None
242+
full_attr = ".".join(attr_list)
243+
return full_attr, origin_attr
233244

234245
def get_canonical_torch_api(self, torch_api):
235246
return GlobalManager.ALIAS_MAPPING.get(torch_api, torch_api)
@@ -315,12 +326,10 @@ def visit_Module(self, node):
315326

316327

317328
class BaseMatcher(object):
318-
def __init__(
319-
self, transformer, torch_api, origin_torch_api, api_mapping_dict, logger
320-
):
329+
def __init__(self, transformer, torch_api, origin_attr, api_mapping_dict, logger):
321330
self.transformer = transformer
322331
self.torch_api = torch_api
323-
self.origin_torch_api = origin_torch_api
332+
self.origin_attr = origin_attr
324333
self.paddle_api = None
325334
self.api_mapping_dict = api_mapping_dict
326335
self.logger = logger
@@ -373,7 +382,7 @@ def parse_args_and_kwargs(
373382
# not support some API args
374383
if k in unsupport_args:
375384
return None
376-
v = astor.to_source(node).replace("\n", "")
385+
v = astor.to_source(node).strip("\n")
377386
# v = ast.unparse(node)
378387
new_kwargs[k] = v
379388

@@ -385,7 +394,7 @@ def parse_args_and_kwargs(
385394
f"Parameter '{k}' specified multiple times - cannot be both positional and keyword argument",
386395
self.transformer.file_name,
387396
)
388-
v = astor.to_source(node.value).replace("\n", "")
397+
v = astor.to_source(node.value).strip("\n")
389398
# v = ast.unparse(node.value)
390399
new_kwargs[k] = v
391400

@@ -396,7 +405,7 @@ def parse_args(self, args):
396405
for node in args:
397406
# if isinstance(node, ast.Starred) and not allow_starred:
398407
# return None
399-
ele = astor.to_source(node).replace("\n", "")
408+
ele = astor.to_source(node).strip("\n")
400409
new_args.append(ele)
401410

402411
return new_args
@@ -413,19 +422,15 @@ def parse_kwargs(self, kwargs, allow_none=False):
413422
# not support some API args
414423
if k in unsupport_args:
415424
return None
416-
v = astor.to_source(node.value).replace("\n", "")
425+
v = astor.to_source(node.value).strip("\n")
417426
new_kwargs[k] = v
418427

419428
return new_kwargs
420429

421430
def parse_func(self, func):
422-
new_func = astor.to_source(func).replace("\n", "")
423-
self.paddleClass = new_func[0 : new_func.rfind(".")]
431+
func_str = astor.to_source(func).strip("\n")
432+
self.paddleClass = func_str[0 : func_str.rfind(".")]
424433
class_str = "paddle.Tensor|paddle.nn.Module|paddle.optimizer.Optimizer|paddle.distribution.Distribution|paddle.autograd.function.FunctionCtx|paddle.profiler.Profiler"
425-
if self.transformer.mode == "min":
426-
class_str += (
427-
"|torch.Tensor|torch.nn.Module|torch.autograd.function.FunctionCtx"
428-
)
429434
if self.get_paddle_api():
430435
new_paddle_api = re.sub(
431436
class_str,
@@ -434,9 +439,7 @@ def parse_func(self, func):
434439
)
435440
# reverse escape
436441
new_paddle_api = re.sub(r"\\(.)", r"\1", new_paddle_api)
437-
self.set_paddle_api(new_paddle_api)
438-
439-
return new_func
442+
self.paddle_api = new_paddle_api
440443

441444
def args_to_str(self, args):
442445
str_list = []
@@ -532,24 +535,21 @@ def enable_utils_code(self):
532535
utils_file_helper.add_code(utils_code)
533536
log_debug(self.logger, "add 'import utils'", self.transformer.file_name)
534537

535-
def set_paddle_api(self, paddle_api):
536-
self.paddle_api = paddle_api
537-
538538
def get_paddle_api(self):
539-
paddle_api = None
539+
ret = None
540540
if self.paddle_api:
541-
paddle_api = self.paddle_api
541+
ret = self.paddle_api
542542
elif "paddle_api" in self.api_mapping_dict:
543-
paddle_api = self.api_mapping_dict["paddle_api"]
543+
ret = self.api_mapping_dict["paddle_api"]
544544
if (
545-
paddle_api
545+
ret
546546
and self.api_mapping_dict.get("abstract")
547547
and self.generate_utils_code() is not None
548548
):
549549
self.enable_utils_code()
550550
if self.api_mapping_dict.get("enable_utils_code"):
551551
self.enable_utils_code()
552-
return paddle_api
552+
return ret
553553

554554
def get_paddle_class_attribute_nodes(self, node):
555555
self.parse_func(node)

0 commit comments

Comments
 (0)