Skip to content

Commit 47f844b

Browse files
committed
fix
1 parent 2556170 commit 47f844b

2 files changed

Lines changed: 53 additions & 10 deletions

File tree

CHANGELOGS.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@ Change Logs
44
0.9.2
55
+++++
66

7+
* pr:`412`: patches for ViTModel (through rewriting)
8+
79
0.9.1
810
+++++
911

onnx_diagnostic/torch_export_patches/patch_module.py

Lines changed: 51 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -144,6 +144,39 @@ def visit_Name(self, node):
144144
self.local_variables.add(node.id)
145145
return node
146146

147+
def is_expression_context(self, node):
148+
if not hasattr(node, "_parent") or node._parent is None:
149+
return False
150+
parent = node._parent
151+
# Common expression contexts:
152+
if isinstance(
153+
parent,
154+
(
155+
ast.BinOp,
156+
ast.UnaryOp,
157+
ast.BoolOp,
158+
ast.Call,
159+
ast.Subscript,
160+
ast.Compare,
161+
ast.Return,
162+
ast.Expr,
163+
ast.If,
164+
ast.While,
165+
),
166+
):
167+
return True
168+
# RHS of assignment: parent is Assign and node is in value
169+
if isinstance(parent, ast.Assign) and node in ast.walk(parent.value):
170+
return True
171+
return False
172+
173+
def _attach_parents(self, node, parent=None):
174+
node._parent = parent
175+
if parent and not hasattr(node, "lineno"):
176+
node.lineno = parent.lineno
177+
for child in ast.iter_child_nodes(node):
178+
self._attach_parents(child, node)
179+
147180
def visit_FunctionDef(self, node):
148181
# Capture argument names for branch functions
149182
old_args = self.current_func_args
@@ -152,8 +185,11 @@ def visit_FunctionDef(self, node):
152185
for n in node.body:
153186
visited = self.visit(n)
154187
if isinstance(visited, list):
188+
for n in visited:
189+
self._attach_parents(n, node)
155190
new_body.extend(visited)
156191
else:
192+
self._attach_parents(visited, node)
157193
new_body.append(visited)
158194
node.body = new_body
159195
self.current_func_args = old_args
@@ -455,10 +491,10 @@ def visit_If(self, node):
455491
f"Inconsistencies between n_returned_values={n_returned_values}, "
456492
f"dropped={dropped}, tgt.elts={tgt.elts}, tgt_elts={tgt_elts}"
457493
)
458-
tgt = ast.Tuple(tgt_elts, ctx=ast.Store())
494+
tgt = ast.Tuple(list(tgt_elts), ctx=ast.Store())
459495

460496
added = {tgt.id} if isinstance(tgt, ast.Name) else set(t.id for t in tgt.elts)
461-
assign = ast.Assign(targets=[tgt], value=call)
497+
assign = ast.Assign(targets=[tgt], value=call, ctx=ast.Store())
462498
ast.copy_location(assign, node)
463499
ast.fix_missing_locations(assign)
464500
self.local_variables = known_local_variables | added
@@ -631,7 +667,7 @@ def visit_For(self, node):
631667
),
632668
],
633669
decorator_list=[],
634-
ctx=ast.Store(),
670+
ctx=ast.Load(),
635671
)
636672

637673
# final rewriting
@@ -654,7 +690,7 @@ def visit_For(self, node):
654690
args=[
655691
ast.Name(id=func_name, ctx=ast.Load()),
656692
ast.List(
657-
elts=[ast.Name(id=v, ctx=ast.Load()) for v in init_vars], ctx=ast.Store()
693+
elts=[ast.Name(id=v, ctx=ast.Load()) for v in init_vars], ctx=ast.Load()
658694
),
659695
ast.List(
660696
elts=[
@@ -700,13 +736,13 @@ def visit_For(self, node):
700736
],
701737
*[ast.Name(id=v, ctx=ast.Load()) for v in scan_vars],
702738
],
703-
ctx=ast.Store(),
739+
ctx=ast.Load(),
704740
),
705741
ast.List(
706742
elts=[
707743
ast.Name(id=v, ctx=ast.Load()) for v in [*scan_shape_vars, *input_vars]
708744
],
709-
ctx=ast.Store(),
745+
ctx=ast.Load(),
710746
),
711747
],
712748
keywords=[],
@@ -923,23 +959,28 @@ def forward(self, x, y):
923959
)
924960
try:
925961
mod = compile(new_tree, filename="<ast>", mode="exec")
926-
except TypeError as e:
927-
if 'required field "lineno" missing from stmt' in str(e):
962+
except (TypeError, ValueError) as e:
963+
se = str(e)
964+
if (
965+
'required field "lineno" missing from' in se
966+
or "expression must have Load context but has Store instead" in se
967+
):
928968
# Could not find a way to avoid compiling a string.
929969
# The error message still pops up without indicating which node is not
930970
# properly set.
931971
code = ast.unparse(new_tree)
932972
try:
933973
mod = compile(code, filename="<source>", mode="exec")
934974
except IndentationError as ee:
935-
raise RuntimeError(f"Unable to compile\n{code}") from ee
975+
raise RuntimeError(f"Unable to compile due to {ee} (and {e})\n{code}") from ee
936976
else:
937977
kws = dict(include_attributes=True, annotate_fields=True, indent=4)
938978
raise RuntimeError(
939-
f"Unable to compile code\n--CODE--\n"
979+
f"Unable to compile code due to {e}\n--CODE--\n"
940980
f"{ast.unparse(new_tree)}\n--TREE--\n"
941981
f"{ast.dump(new_tree, **kws)}"
942982
) from e
983+
943984
namespace: Dict[str, type] = {}
944985
globs = func.__globals__.copy()
945986
exec(mod, globs, namespace)

0 commit comments

Comments
 (0)