Skip to content

Commit ca4f1b2

Browse files
committed
[Relax] Fix matmul and reductions with zero-size dimension return uninitialized memory
1 parent aa59644 commit ca4f1b2

2 files changed

Lines changed: 47 additions & 3 deletions

File tree

python/tvm/relax/transform/legalize_ops/linear_algebra.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,8 @@ def te_matmul(a: te.Tensor, b: te.Tensor) -> te.Tensor:
4545
b_relax = relax.Var("b", relax.TensorStructInfo(b.shape))
4646
f_infer_sinfo = call.op.get_attr("FInferStructInfo")
4747
output_shape = f_infer_sinfo(relax.op.matmul(a_relax, b_relax), bb).shape
48+
if isinstance(a_shape[-1], tirx.IntImm) and a_shape[-1] == 0:
49+
return topi.full(output_shape, call.struct_info.dtype, 0)
4850

4951
def matmul_compute(*idx_spatial):
5052
k = te.reduce_axis((0, a_shape[-1]), name="k")

python/tvm/relax/transform/legalize_ops/statistical.py

Lines changed: 45 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,15 +17,54 @@
1717
# pylint: disable=invalid-name
1818
"""Default legalization function for statistical operators."""
1919

20+
from typing import Callable
21+
2022
from tvm import te, tirx, topi
2123

2224
from ...block_builder import BlockBuilder
2325
from ...expr import Call, Expr
2426
from .common import LegalizeFunc, TEFunc, register_legalize
2527

2628

27-
def _statistical(te_func: TEFunc) -> LegalizeFunc:
29+
def _normalize_reduction_axes(axis: list[int] | None, ndim: int) -> list[int]:
30+
if axis is None:
31+
return list(range(ndim))
32+
33+
axes = []
34+
for dim in axis:
35+
if isinstance(dim, trix.IntImm):
36+
dim = dim.value
37+
dim = int(dim)
38+
axes.append(dim + ndim if dim < 0 else dim)
39+
return axes
40+
41+
42+
def _has_const_zero_reduction_dim(call: Call) -> bool:
    """Check whether ``call`` reduces over a dimension of constant extent 0.

    Parameters
    ----------
    call : Call
        The reduction call. Its first argument is the tensor being reduced
        and ``call.attrs.axis`` lists the reduced axes.

    Returns
    -------
    result : bool
        ``True`` when at least one reduced axis of the input has an extent
        that is a ``tirx.IntImm`` equal to 0. ``False`` when the input shape
        is not a ``ShapeExpr`` or no reduced axis has such an extent.
    """
    # Imported here because the module-level import from ``...expr`` only
    # brings in ``Call`` and ``Expr``.
    from ...expr import ShapeExpr  # pylint: disable=import-outside-toplevel

    input_shape = call.args[0].struct_info.shape
    if not isinstance(input_shape, ShapeExpr):
        return False

    extents = input_shape.values
    axes = _normalize_reduction_axes(call.attrs.axis, len(extents))
    # Only constant extents are inspected; symbolic extents are skipped by
    # the isinstance test.
    return any(
        isinstance(extents[dim], tirx.IntImm) and extents[dim].value == 0 for dim in axes
    )
49+
50+
51+
def _statistical(
    te_func: TEFunc,
    zero_dim_identity: int | float | bool | Callable[[str], int | float | bool] | None = None,
) -> LegalizeFunc:
    """Build the legalization function for a reduction operator.

    Parameters
    ----------
    te_func : TEFunc
        TE function that receives the input tensor, ``axis`` and
        ``keepdims`` of the call through ``BlockBuilder.call_te``.

    zero_dim_identity : int | float | bool | Callable | None
        Value used to fill the output when the call reduces over a dimension
        of constant extent 0. A callable is invoked with the output dtype
        and its result is used as the fill value. ``None`` disables the
        special case, so ``te_func`` is always used.

    Returns
    -------
    legalize_func : LegalizeFunc
        Function taking a ``BlockBuilder`` and a ``Call`` and returning the
        legalized expression.
    """

    def statistical_call_te(bb: BlockBuilder, call: Call) -> Expr:
        fill_with_identity = zero_dim_identity is not None and _has_const_zero_reduction_dim(
            call
        )
        if not fill_with_identity:
            # Regular path: lower the reduction through the TE function.
            return bb.call_te(te_func, call.args[0], call.attrs.axis, call.attrs.keepdims)

        out_sinfo = call.struct_info
        if callable(zero_dim_identity):
            fill_value = zero_dim_identity(out_sinfo.dtype)
        else:
            fill_value = zero_dim_identity
        # Emit a constant-filled tensor instead of running the reduction.
        return bb.call_te(topi.full, out_sinfo.shape.values, out_sinfo.dtype, fill_value)

    return statistical_call_te
@@ -129,5 +168,8 @@ def _median(bb: BlockBuilder, call: Call) -> Expr:
129168

130169
register_legalize("relax.max", _statistical(topi.max))
131170
register_legalize("relax.min", _statistical(topi.min))
132-
register_legalize("relax.prod", _statistical(topi.prod))
133-
register_legalize("relax.sum", _statistical(topi.sum))
171+
def _prod_identity(dtype: str) -> int | bool:
    """Return the fill value for a product over a zero-size dimension.

    ``True`` is returned for the ``"bool"`` dtype and ``1`` for every other
    dtype.
    """
    if dtype == "bool":
        return True
    return 1


# prod and sum fill the output with a constant when a reduced dimension has
# constant extent 0.
register_legalize("relax.prod", _statistical(topi.prod, zero_dim_identity=_prod_identity))
register_legalize("relax.sum", _statistical(topi.sum, zero_dim_identity=0))

0 commit comments

Comments
 (0)