|
17 | 17 | # pylint: disable=invalid-name |
18 | 18 | """Default legalization function for statistical operators.""" |
19 | 19 |
|
| 20 | +from typing import Callable |
| 21 | + |
20 | 22 | from tvm import te, tirx, topi |
21 | 23 |
|
22 | 24 | from ...block_builder import BlockBuilder |
23 | 25 | from ...expr import Call, Expr |
24 | 26 | from .common import LegalizeFunc, TEFunc, register_legalize |
25 | 27 |
|
26 | 28 |
|
27 | | -def _statistical(te_func: TEFunc) -> LegalizeFunc: |
| 29 | +def _normalize_reduction_axes(axis: list[int] | None, ndim: int) -> list[int]: |
| 30 | + if axis is None: |
| 31 | + return list(range(ndim)) |
| 32 | + |
| 33 | + axes = [] |
| 34 | + for dim in axis: |
| 35 | + if isinstance(dim, trix.IntImm): |
| 36 | + dim = dim.value |
| 37 | + dim = int(dim) |
| 38 | + axes.append(dim + ndim if dim < 0 else dim) |
| 39 | + return axes |
| 40 | + |
| 41 | + |
| 42 | +def _has_const_zero_reduction_dim(call: Call) -> bool: |
| 43 | + input_shape = call.args[0].struct_info.shape |
| 44 | + if not isinstance(input_shape, ShapeExpr): |
| 45 | + return false |
| 46 | + |
| 47 | + axes = _normalize_reduction_axes(call.attrs.axis, len(input_shape.values)) |
| 48 | + return any(isinstance(input_shape.values[dim], tirx.IntImm) and input_shape.values[dim] == 0 for dim in axes) |
| 49 | + |
| 50 | + |
| 51 | +def _statistical( |
| 52 | + te_func: TEFunc, |
| 53 | + zero_dim_identity: int | float | bool | Callable[[str], int | float | bool] | None = None, |
| 54 | +) -> LegalizeFunc: |
28 | 55 | def statistical_call_te(bb: BlockBuilder, call: Call) -> Expr: |
| 56 | + if zero_dim_identity is not None and _has_const_zero_reduction_dim(call): |
| 57 | + fill_value = ( |
| 58 | + zero_dim_identity(call.struct_info.dtype) |
| 59 | + if callable(zero_dim_identity) |
| 60 | + else zero_dim_identity |
| 61 | + ) |
| 62 | + return bb.call_te( |
| 63 | + topi.full, |
| 64 | + call.struct_info.shape.values, |
| 65 | + call.struct_info.dtype, |
| 66 | + fill_value, |
| 67 | + ) |
29 | 68 | return bb.call_te(te_func, call.args[0], call.attrs.axis, call.attrs.keepdims) |
30 | 69 |
|
31 | 70 | return statistical_call_te |
@@ -129,5 +168,8 @@ def _median(bb: BlockBuilder, call: Call) -> Expr: |
129 | 168 |
|
130 | 169 | register_legalize("relax.max", _statistical(topi.max)) |
131 | 170 | register_legalize("relax.min", _statistical(topi.min)) |
132 | | -register_legalize("relax.prod", _statistical(topi.prod)) |
133 | | -register_legalize("relax.sum", _statistical(topi.sum)) |
| 171 | +register_legalize( |
| 172 | + "relax.prod", |
| 173 | + _statistical(topi.prod, zero_dim_identity=lambda dtype: True if dtype == "bool" else 1), |
| 174 | +) |
| 175 | +register_legalize("relax.sum", _statistical(topi.sum, zero_dim_identity=0)) |
0 commit comments