@@ -7,10 +7,13 @@ def use_ntops():
77 import ntops
88
99 return _TemporaryAttributes (
10- (("ntops.torch.torch" , infinicore ),)
11- + tuple (
10+ tuple (
1211 (f"infinicore.{ op_name } " , getattr (ntops .torch , op_name ))
1312 for op_name in ntops .torch .__all__
13+ )
14+ + tuple (
15+ (f"ntops.torch.{ op_name } .__globals__['torch']" , infinicore )
16+ for op_name in ntops .torch .__all__
1417 )
1518 )
1619
@@ -23,33 +26,68 @@ def __init__(self, attribute_mappings):
2326
2427 def __enter__ (self ):
2528 for attr_path , new_value in self ._attribute_mappings :
26- parent , attr_name = self ._resolve_path (attr_path )
29+ parent , attr_name , is_dict_key = self ._resolve_path (attr_path )
2730
2831 try :
29- self ._original_values [attr_path ] = getattr (parent , attr_name )
30- except AttributeError :
32+ if is_dict_key :
33+ self ._original_values [attr_path ] = parent .__globals__ [attr_name ]
34+ else :
35+ self ._original_values [attr_path ] = getattr (parent , attr_name )
36+ except (AttributeError , KeyError ):
3137 pass
3238
33- setattr (parent , attr_name , new_value )
39+ if is_dict_key :
40+ parent .__globals__ [attr_name ] = new_value
41+ else :
42+ setattr (parent , attr_name , new_value )
3443
3544 return self
3645
3746 def __exit__ (self , exc_type , exc_value , traceback ):
3847 for attr_path , _ in self ._attribute_mappings :
39- parent , attr_name = self ._resolve_path (attr_path )
48+ parent , attr_name , is_dict_key = self ._resolve_path (attr_path )
4049
4150 if attr_path in self ._original_values :
42- setattr (parent , attr_name , self ._original_values [attr_path ])
51+ original_value = self ._original_values [attr_path ]
52+ if is_dict_key :
53+ parent .__globals__ [attr_name ] = original_value
54+ else :
55+ setattr (parent , attr_name , original_value )
4356 else :
44- delattr (parent , attr_name )
57+ if is_dict_key :
58+ if attr_name in parent .__globals__ .keys ():
59+ del parent .__globals__ [attr_name ]
60+ else :
61+ if parent is not None and attr_name is not None :
62+ delattr (parent , attr_name )
4563
4664 @staticmethod
4765 def _resolve_path (path ):
66+ is_dict_key = False
67+ dict_key_match = None
68+
69+ if path .endswith ("']" ):
70+ try :
71+ start_index = path .rindex ("['" )
72+ end_index = path .rindex ("']" )
73+
74+ if start_index > 0 and end_index == len (path ) - 2 :
75+ is_dict_key = True
76+ dict_key_match = path [start_index + 2 : end_index ]
77+ path = path [:start_index ]
78+ except ValueError :
79+ pass
80+
4881 * parent_parts , attr_name = path .split ("." )
4982
5083 curr = sys .modules [parent_parts [0 ]]
5184
5285 for part in parent_parts [1 :]:
5386 curr = getattr (curr , part )
87+
88+ parent = curr
5489
55- return curr , attr_name
90+ if is_dict_key :
91+ return parent , dict_key_match , True
92+ else :
93+ return parent , attr_name , False
0 commit comments