-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathopns.py
More file actions
104 lines (83 loc) · 2 KB
/
opns.py
File metadata and controls
104 lines (83 loc) · 2 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
""" MRT operator names """

# Each constant below binds a symbolic Python name to the string used to
# identify an operator.  Compare against these constants instead of
# repeating the raw strings.

VAR = "var"

DROP_OUT = "nn.dropout"
CONV2D = "nn.conv2d"
DENSE = "nn.dense"
BATCH_NORM = "nn.batch_norm"
# BIAS_ADD = "nn.bias_add"
RELU = "nn.relu"
HARDTANH = "nn.hardtanh"
SILU = "nn.silu"
LEAKY_RELU = "nn.leaky_relu"
ADAPTIVE_AVG_POOL2D = "nn.adaptive_avg_pool2d"
AVG_POOL2D = "nn.avg_pool2d"
MAX_POOL2D = "nn.max_pool2d"

SOFTMAX = "nn.softmax"
LOG_SOFTMAX = "nn.log_softmax"

EXP = "exp"
SIGMOID = "sigmoid"

SUM = "sum"
MEAN = "mean"
MAX_AXIS = "max"
MAXIMUM = "maximum"
MINIMUM = "minimum"

# =========== NON-CALC ops ===============
TUPLE = "Tuple"
TUPLE_GET_ITEM = "TupleGetItem"

REPEAT = "repeat"
SQUEEZE = "squeeze"
FLATTEN = "flatten"
BATCH_FLATTEN = "nn.batch_flatten"
RESHAPE = "reshape"
CONCAT = "concatenate"
SPLIT = "split"
TRANSPOSE = "transpose"
BROADCAST_TO = "broadcast_to"

EXPAND_DIMS = "expand_dims"
TILE = "tile"

WHERE = "where"
GREATER = "greater"
STRIDED_SLICE = "strided_slice"
SLICE_LIKE = "slice_like"
GET_VALID_COUNT = "vision.get_valid_counts"
# NOTE: the original identifier is misspelled ("SUPRESSION").  It is kept so
# existing callers keep working; new code should use the correctly spelled
# alias defined right after it.
NON_MAX_SUPRESSION = "vision.non_max_suppression"
NON_MAX_SUPPRESSION = NON_MAX_SUPRESSION

# Relax renames the clip attrs from a_min/a_max to min/max.
CLIP = "clip"
CEIL = "ceil"
RIGHT_SHIFT = "right_shift"
# Relax supports astype instead of cast.
AS_TYPE = "astype"
# CAST = "cast"

ADV_INDEX = "adv_index"

CALL_TIR = "call_tir"
CALL_DPS_PACKED = "call_dps_packed"

# ======= binary ops =============
ADD = "add"
SUB = "subtract"
MUL = "multiply"
MATMUL = "matmul"
DIV = "divide"

# ======= unary ops ==============
NEGATIVE = "negative"
ABS = "abs"
LOG = "log"
SQRT = "sqrt"
POW = "pow"

PASS = "pass"

# ======= auto generate op =========
ARANGE = "arange"
ZEROS_LIKE = "zeros_like"
ONES_LIKE = "ones_like"

# ======= control flow op ===========
IF = "if"
ARGWHERE = "argwhere"

# ======= mrt requant op ==========
REQUANT = "mrt.requant"
PCLIP = "mrt.pclip"
""" precision clip """
RS_PCLIP = "mrt.rs_pclip"
""" right shift precision clip """
LUT = "mrt.lut"
""" look up table, equals adv_index in tvm """
def Opname2Funcname(op_name: str) -> str:
    """Return *op_name* with every ``.`` replaced by ``_``.

    Names without a dot are returned unchanged.

    >>> Opname2Funcname("nn.conv2d")
    'nn_conv2d'
    >>> Opname2Funcname("add")
    'add'
    """
    return op_name.replace('.', '_')