Skip to content

Commit 47e244c

Browse files
rootroot
authored andcommitted
issue/551 debug context management by adding __globals__ resolve
1 parent 5b7ef9c commit 47e244c

1 file changed

Lines changed: 48 additions & 10 deletions

File tree

python/infinicore/ntops.py

Lines changed: 48 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,13 @@ def use_ntops():
77
import ntops
88

99
return _TemporaryAttributes(
10-
(("ntops.torch.torch", infinicore),)
11-
+ tuple(
10+
tuple(
1211
(f"infinicore.{op_name}", getattr(ntops.torch, op_name))
1312
for op_name in ntops.torch.__all__
13+
)
14+
+ tuple(
15+
(f"ntops.torch.{op_name}.__globals__['torch']", infinicore)
16+
for op_name in ntops.torch.__all__
1417
)
1518
)
1619

@@ -23,33 +26,68 @@ def __init__(self, attribute_mappings):
2326

2427
def __enter__(self):
2528
for attr_path, new_value in self._attribute_mappings:
26-
parent, attr_name = self._resolve_path(attr_path)
29+
parent, attr_name, is_dict_key = self._resolve_path(attr_path)
2730

2831
try:
29-
self._original_values[attr_path] = getattr(parent, attr_name)
30-
except AttributeError:
32+
if is_dict_key:
33+
self._original_values[attr_path] = parent.__globals__[attr_name]
34+
else:
35+
self._original_values[attr_path] = getattr(parent, attr_name)
36+
except (AttributeError, KeyError):
3137
pass
3238

33-
setattr(parent, attr_name, new_value)
39+
if is_dict_key:
40+
parent.__globals__[attr_name] = new_value
41+
else:
42+
setattr(parent, attr_name, new_value)
3443

3544
return self
3645

3746
def __exit__(self, exc_type, exc_value, traceback):
3847
for attr_path, _ in self._attribute_mappings:
39-
parent, attr_name = self._resolve_path(attr_path)
48+
parent, attr_name, is_dict_key = self._resolve_path(attr_path)
4049

4150
if attr_path in self._original_values:
42-
setattr(parent, attr_name, self._original_values[attr_path])
51+
original_value = self._original_values[attr_path]
52+
if is_dict_key:
53+
parent.__globals__[attr_name] = original_value
54+
else:
55+
setattr(parent, attr_name, original_value)
4356
else:
44-
delattr(parent, attr_name)
57+
if is_dict_key:
58+
if attr_name in parent.__globals__.keys():
59+
del parent.__globals__[attr_name]
60+
else:
61+
if parent is not None and attr_name is not None:
62+
delattr(parent, attr_name)
4563

4664
@staticmethod
4765
def _resolve_path(path):
66+
is_dict_key = False
67+
dict_key_match = None
68+
69+
if path.endswith("']"):
70+
try:
71+
start_index = path.rindex("['")
72+
end_index = path.rindex("']")
73+
74+
if start_index > 0 and end_index == len(path) - 2:
75+
is_dict_key = True
76+
dict_key_match = path[start_index + 2 : end_index]
77+
path = path[:start_index]
78+
except ValueError:
79+
pass
80+
4881
*parent_parts, attr_name = path.split(".")
4982

5083
curr = sys.modules[parent_parts[0]]
5184

5285
for part in parent_parts[1:]:
5386
curr = getattr(curr, part)
87+
88+
parent = curr
5489

55-
return curr, attr_name
90+
if is_dict_key:
91+
return parent, dict_key_match, True
92+
else:
93+
return parent, attr_name, False

0 commit comments

Comments
 (0)