Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 48 additions & 10 deletions python/infinicore/ntops.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,13 @@ def use_ntops():
import ntops

return _TemporaryAttributes(
(("ntops.torch.torch", infinicore),)
+ tuple(
tuple(
(f"infinicore.{op_name}", getattr(ntops.torch, op_name))
for op_name in ntops.torch.__all__
)
+ tuple(
(f"ntops.torch.{op_name}.__globals__['torch']", infinicore)
for op_name in ntops.torch.__all__
)
)

Expand All @@ -23,33 +26,68 @@ def __init__(self, attribute_mappings):

def __enter__(self):
for attr_path, new_value in self._attribute_mappings:
parent, attr_name = self._resolve_path(attr_path)
parent, attr_name, is_dict_key = self._resolve_path(attr_path)

try:
self._original_values[attr_path] = getattr(parent, attr_name)
except AttributeError:
if is_dict_key:
self._original_values[attr_path] = parent.__globals__[attr_name]
else:
self._original_values[attr_path] = getattr(parent, attr_name)
except (AttributeError, KeyError):
pass

setattr(parent, attr_name, new_value)
if is_dict_key:
parent.__globals__[attr_name] = new_value
else:
setattr(parent, attr_name, new_value)

return self

def __exit__(self, exc_type, exc_value, traceback):
for attr_path, _ in self._attribute_mappings:
parent, attr_name = self._resolve_path(attr_path)
parent, attr_name, is_dict_key = self._resolve_path(attr_path)

if attr_path in self._original_values:
setattr(parent, attr_name, self._original_values[attr_path])
original_value = self._original_values[attr_path]
if is_dict_key:
parent.__globals__[attr_name] = original_value
else:
setattr(parent, attr_name, original_value)
else:
delattr(parent, attr_name)
if is_dict_key:
if attr_name in parent.__globals__.keys():
del parent.__globals__[attr_name]
else:
if parent is not None and attr_name is not None:
delattr(parent, attr_name)

@staticmethod
def _resolve_path(path):
is_dict_key = False
dict_key_match = None

if path.endswith("']"):
try:
start_index = path.rindex("['")
end_index = path.rindex("']")

if start_index > 0 and end_index == len(path) - 2:
is_dict_key = True
dict_key_match = path[start_index + 2 : end_index]
path = path[:start_index]
except ValueError:
pass

*parent_parts, attr_name = path.split(".")

curr = sys.modules[parent_parts[0]]

for part in parent_parts[1:]:
curr = getattr(curr, part)

parent = curr

return curr, attr_name
if is_dict_key:
return parent, dict_key_match, True
else:
return parent, attr_name, False