-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathvm.py
More file actions
46 lines (40 loc) · 1.38 KB
/
vm.py
File metadata and controls
46 lines (40 loc) · 1.38 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
import typing
import torch
from collections import namedtuple
from .converter import *
from .types import *
from mrt.mir.symbol import *
from mrt.mir.mhsymbol import MultiHeadSymbol
from mrt.common.types import *
# Executor pairs a PyTorch module ("vm") with the torch.device it was moved to,
# so callers can unpack both when running inference.
Executor = namedtuple("Executor", ["vm", "device"])
def create_executor(
        symbol: MultiHeadSymbol, params: ParametersT,
        device: typing.Union[str, torch.device] = "cpu",
        target: str = "",
) -> Executor:
    """Convert an MRT graph into an eval-mode PyTorch module on *device*.

    Args:
        symbol: the MRT multi-head graph to convert.
        params: parameter tensors consumed by the converter.
        device: device name (e.g. "cpu", "cuda:0") or a ``torch.device``;
            strings are normalized to ``torch.device`` below.
        target: accepted for interface compatibility but currently unused
            by the PyTorch backend.

    Returns:
        Executor namedtuple of (module on device, torch.device).
    """
    mod = mrt_to_pytorch(symbol, params)
    # Inference only: disable dropout/batch-norm training behavior.
    mod.eval()
    if not isinstance(device, torch.device):
        device = torch.device(device)
    return Executor(mod.to(device), device)
def run_executor(
        executor: Executor,
        data: typing.Optional[np.ndarray] = None,
        data_dict: typing.Optional[ParametersT] = None) -> OpNumpyT:
    """Run the executor's module on numpy inputs and return MRT output.

    Args:
        executor: (module, device) pair from ``create_executor``.
        data: optional positional numpy input; forwarded as the first
            argument to the module (may be None).
        data_dict: optional mapping of keyword-argument name -> numpy
            array. Not mutated; a fresh tensor dict is built instead.

    Returns:
        The module output converted back via ``data_to_mrt``.
    """
    (vm, device) = executor
    # BUG FIX: the original used a mutable default ({}) shared across calls
    # and overwrote the caller's dict entries with torch tensors in place.
    # Build a new dict so the caller's numpy arrays are left untouched.
    if data_dict is None:
        data_dict = {}
    tensor_dict = {
        k: torch.from_numpy(v).to(device) for k, v in data_dict.items()
    }
    if data is not None:
        data = torch.from_numpy(data).to(device)
    with torch.no_grad():
        out = vm(data, **tensor_dict)
    return data_to_mrt(out)
def infer(graph: MultiHeadSymbol, params: ParametersT,
        data: typing.Optional[np.ndarray] = None,
        data_dict: typing.Optional[ParametersT] = None,
        device: str = "cpu",
        **kwargs):
    """One-shot convenience: build an executor for *graph* and run it.

    Args:
        graph: MRT multi-head graph to execute.
        params: parameter tensors for the graph.
        data: optional positional numpy input.
        data_dict: optional keyword-name -> numpy-array inputs.
        device: target device string, default "cpu".
        **kwargs: forwarded to ``create_executor`` (e.g. ``target``).

    Returns:
        The MRT-format output of the executed graph.
    """
    # BUG FIX: replaced the mutable default argument ({}) with a None
    # sentinel; normalize here so run_executor always receives a dict.
    if data_dict is None:
        data_dict = {}
    executor = create_executor(graph, params, device=device, **kwargs)
    return run_executor(executor, data, data_dict)