|
48 | 48 | TanHOpr, |
49 | 49 | TransposeOpr, |
50 | 50 | TrueDivOpr, |
| 51 | + _ConvOpr, |
51 | 52 | _PoolOpr, |
52 | 53 | ) |
53 | 54 | from .ir_tensor import AxisOrder, IRTensor |
@@ -121,6 +122,7 @@ class TransformerRule(Enum): |
121 | 122 | TRANSPOSE_LINEAR_WEIGHT_TO_NHWC = 133 |
122 | 123 | # force fc with no trans for megengine |
123 | 124 | FC_NO_TRANS = 134 |
| 125 | + BIAS_ASTYPE_INT64 = 135 |
124 | 126 |
|
125 | 127 |
|
126 | 128 | def cmp_rules(a, b): |
@@ -1537,3 +1539,29 @@ def trans_tensor(tensor): |
1537 | 1539 | if opr.transpose_b and tensor_b.owner_opr == None: |
1538 | 1540 | opr.transpose_b = False |
1539 | 1541 | trans_tensor(tensor_b) |
| 1542 | + |
| 1543 | + |
| 1544 | +@_register_tranformation_rule(TransformerRule.BIAS_ASTYPE_INT64) |
| 1545 | +def _bias_astype_int64(net: IRGraph): |
| 1546 | + for opr in net.all_oprs: |
| 1547 | + if not isinstance(opr, (MatMulOpr, _ConvOpr)): |
| 1548 | + continue |
| 1549 | + bias = None |
| 1550 | + if isinstance(opr, MatMulOpr) and len(opr.inp_tensors) == 3: |
| 1551 | + bias = opr.inp_tensors[2] |
| 1552 | + elif isinstance(opr, Deconv2dOpr) and len(opr.inp_tensors) > 3: |
| 1553 | + if ( |
| 1554 | + opr.inp_tensors[0].shape == [4] and len(opr.inp_tensors) == 4 |
| 1555 | + ): # shape as input |
| 1556 | + bias = opr.inp_tensors[-1] |
| 1557 | + if len(opr.inp_tensors) == 4: |
| 1558 | + bias = opr.inp_tensors[-1] |
| 1559 | + elif isinstance(opr, (Conv2dOpr, ConvRelu2dOpr)) and len(opr.inp_tensors) == 3: |
| 1560 | + bias = opr.inp_tensors[-1] |
| 1561 | + if bias is not None and bias.scale is not None: |
| 1562 | + bias.set_qparams( |
| 1563 | + scale=bias.scale, |
| 1564 | + zero_point=bias.zero_point, |
| 1565 | + q_dtype="int64", |
| 1566 | + np_dtype="int64", |
| 1567 | + ) |
0 commit comments