Skip to content

Commit 8855b5f

Browse files
committed
加入 infinicore.use_ntops
1 parent 0c188f2 commit 8855b5f

2 files changed

Lines changed: 58 additions & 0 deletions

File tree

python/infinicore/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
short,
2525
uint8,
2626
)
27+
from infinicore.ntops import use_ntops
2728
from infinicore.ops.matmul import matmul
2829
from infinicore.ops.rearrange import rearrange
2930
from infinicore.tensor import (
@@ -62,6 +63,8 @@
6263
"long",
6364
"short",
6465
"uint8",
66+
# `ntops` integration.
67+
"use_ntops",
6568
# Operations.
6669
"matmul",
6770
"rearrange",

python/infinicore/ntops.py

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
import sys
2+
3+
import infinicore
4+
5+
6+
def use_ntops():
7+
import ntops
8+
9+
return _TemporaryAttributes(
10+
(("ntops.torch.torch", infinicore),)
11+
+ tuple(
12+
(f"infinicore.{op_name}", getattr(ntops.torch, op_name))
13+
for op_name in ntops.torch.__all__
14+
)
15+
)
16+
17+
18+
class _TemporaryAttributes:
19+
def __init__(self, attribute_mappings):
20+
self._attribute_mappings = attribute_mappings
21+
22+
self._original_values = {}
23+
24+
def __enter__(self):
25+
for attr_path, new_value in self._attribute_mappings:
26+
parent, attr_name = self._resolve_path(attr_path)
27+
28+
try:
29+
self._original_values[attr_path] = getattr(parent, attr_name)
30+
except AttributeError:
31+
pass
32+
33+
setattr(parent, attr_name, new_value)
34+
35+
return self
36+
37+
def __exit__(self, exc_type, exc_value, traceback):
38+
for attr_path, _ in self._attribute_mappings:
39+
parent, attr_name = self._resolve_path(attr_path)
40+
41+
if attr_path in self._original_values:
42+
setattr(parent, attr_name, self._original_values[attr_path])
43+
else:
44+
delattr(parent, attr_name)
45+
46+
@staticmethod
47+
def _resolve_path(path):
48+
*parent_parts, attr_name = path.split(".")
49+
50+
curr = sys.modules[parent_parts[0]]
51+
52+
for part in parent_parts[1:]:
53+
curr = getattr(curr, part)
54+
55+
return curr, attr_name

0 commit comments

Comments
 (0)