-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathapi.py
More file actions
113 lines (93 loc) · 3.18 KB
/
api.py
File metadata and controls
113 lines (93 loc) · 3.18 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
import os
import importlib
import sys
from functools import wraps
from mrt.mir.symbol import *
from mrt.mir.mhsymbol import MultiHeadSymbol, Graph
from mrt.common.types import *
from mrt.common.config import MRTConfig
class Singleton(object):
    """Base class that gives each subclass exactly one shared instance.

    Subsequent constructions of the same subclass return the object
    created by the first construction.
    """

    def __new__(cls, *args, **kw):
        # Look in cls.__dict__ directly: hasattr() would also find an
        # `_instance` inherited from a parent class, wrongly handing a
        # parent's singleton to every subclass.
        if '_instance' not in cls.__dict__:
            # object.__new__ rejects extra positional/keyword arguments
            # in Python 3 when __new__ is overridden, so do NOT forward
            # *args/**kw here — Python still passes them to __init__.
            cls._instance = super().__new__(cls)
        return cls._instance
class DynamicModule(Singleton):
    """Singleton registry of frontend-dispatched API functions.

    Functions decorated with :meth:`typedef_mod_function` declare an API
    slot (keyed by function name); :meth:`load_mod` later fills each slot
    with the implementation found in ``mrt.frontend.<name>``.
    """

    def __init__(self):
        # Singleton.__new__ hands back the same object every time, but
        # Python still calls __init__ on each DynamicModule() expression.
        # Guard the registry so a second instantiation does not silently
        # erase every registered function.
        if not hasattr(self, '_funcs'):
            self._funcs = {}

    def load_mod(self, frontend):
        """Bind every declared slot to its implementation in *frontend*.

        Returns self on success (allows chaining); returns None when the
        frontend module cannot be imported (best-effort: errors are
        printed, not raised, matching the original contract).
        """
        try:
            frontend_module = importlib.import_module(
                f".{frontend}", package="mrt.frontend")
        except ImportError as e:
            print(f"Error: Frontend '{frontend}' cannot be imported: {e}")
            return
        for fname in self._funcs:
            if hasattr(frontend_module, fname):
                self._funcs[fname] = getattr(frontend_module, fname)
            else:
                # Missing implementations are reported but tolerated; the
                # slot stays None and fails only when actually called.
                print(f"Error: function '{fname}' not found in frontend '{frontend}'")
        return self

    def typedef_mod_function(self, func):
        """Decorator declaring *func* as a frontend-dispatched API slot.

        The decorated stub body is executed (normally a no-op) and the
        call is then forwarded to the registered frontend implementation,
        whose return value is the one callers receive.
        """
        fname = func.__name__
        # Reserve the slot without clobbering an already-loaded impl.
        self._funcs.setdefault(fname, None)

        @wraps(func)
        def _func_impl(*args, **kwargs):
            assert self._funcs[fname] is not None, \
                f"func:{fname} not registered in mod: {self._funcs.keys()}"
            func(*args, **kwargs)
            return self._funcs[fname](*args, **kwargs)
        return _func_impl
mod = DynamicModule()
@mod.typedef_mod_function
def create_executor(
    symbol: MultiHeadSymbol, params: ParametersT,
    device: str = "cpu",
    target: str = "", # no use in pytorch frontend
):
    """ Create Runtime Executor for Model Inference.

    Declaration stub: the body is a no-op and the call is dispatched to
    the implementation registered by the active frontend via `mod`.
    """
    pass
@mod.typedef_mod_function
def run_executor(
    executor,
    data: typing.Optional[np.ndarray] = None,
    # NOTE(review): mutable default — harmless here because the dispatch
    # wrapper forwards the caller's raw args, so this default is never
    # applied to the forwarded call; the frontend impl's own default wins.
    data_dict: ParametersT = {}
) -> OpNumpyT:
    """ Apply data to executor.

    Declaration stub dispatched to the active frontend via `mod`.
    """
    pass
@mod.typedef_mod_function
def infer(
    graph: MultiHeadSymbol,
    params: ParametersT,
    data: typing.Optional[np.ndarray] = None,
    data_dict: ParametersT = {},
    device: str = "cpu",
    **kwargs):
    """ Convenient method to infer a model in one call.

    Declaration stub dispatched to the active frontend via `mod`;
    presumably combines create_executor + run_executor — confirm against
    the frontend implementation.
    """
    pass
@mod.typedef_mod_function
def data_from_frontend(data: typing.Any) -> OpNumpyT:
    """ Convert Frontend Tensor to MRT DType.

    Declaration stub dispatched to the active frontend via `mod`.
    """
    pass
@mod.typedef_mod_function
def data_to_frontend(data: OpNumpyT):
    """ Convert MRT DType to Frontend Tensor.

    Declaration stub dispatched to the active frontend via `mod`.
    """
    pass
@mod.typedef_mod_function
def model_from_frontend(
    fe_model,
    # NOTE(review): mutable default — inert here, since the dispatch
    # wrapper forwards raw call args and never applies this default.
    func_names: typing.List[str] = [ "main", ]
) -> typing.Tuple[MultiHeadSymbol, ParametersT]:
    """ Convert Frontend Graph to MRT Symbol/Params.

    Declaration stub dispatched to the active frontend via `mod`.
    """
    pass
@mod.typedef_mod_function
def model_to_frontend(graph: MultiHeadSymbol, params: ParametersT,):
    """ Convert MRT Symbol/Params to Frontend Graph.

    Declaration stub dispatched to the active frontend via `mod`.
    """
    pass
@mod.typedef_mod_function
def type_infer(symbol: Symbol) -> Symbol:
    """ Shape/DType Inference use Frontend API.

    Declaration stub dispatched to the active frontend via `mod`; the
    docstring is the entire body (no trailing `pass`, unlike its
    siblings — valid, the wrapper ignores the stub's return anyway).
    """
# Import-time binding: read the frontend selected in the global MRT config
# and load its implementations into `mod`, so the stubs above dispatch
# correctly as soon as this module is imported.  FRONTEND stays a
# module-level name (other modules may import it).
FRONTEND = MRTConfig.G().frontend
mod.load_mod(FRONTEND)