Skip to content

Commit ca68bef

Browse files
authored
[Relax][TVMScript] Print ExternFunc struct_info when non-default (#19416)
### Summary 1. Add HasDefaultExternFuncStructInfo helper to detect default FuncStructInfo for extern functions. 2. Update relax::ExternFunc printer to: - emit global_symbol using the correct AccessPath attribute key, - conditionally include struct_info only when it differs from the default inferred-by-sinfo-args derive function, - use a variadic args array instead of a single positional literal to prepare the ExternFunc call. 3. This reduces noisy/redundant output when printing ExternFunc nodes while preserving explicit struct_info when it conveys meaningful information. --------- Co-authored-by: cchung100m <cchung100m@users.noreply.github.com>
1 parent 545c332 commit ca68bef

File tree

2 files changed

+50
-2
lines changed

2 files changed

+50
-2
lines changed

src/script/printer/relax/function.cc

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,15 @@ namespace tvm {
2222
namespace script {
2323
namespace printer {
2424

25+
static bool HasDefaultExternFuncStructInfo(const relax::ExternFunc& n) {
26+
const auto* sinfo = n->struct_info_.as<relax::FuncStructInfoNode>();
27+
if (sinfo == nullptr || sinfo->params.defined() || sinfo->purity ||
28+
!sinfo->ret->IsInstance<relax::ObjectStructInfoNode>()) {
29+
return false;
30+
}
31+
return true;
32+
}
33+
2534
bool AtTopLevelFunction(const IRDocsifier& d) {
2635
// fewer than 2 frames: not in a function at all
2736
if (d->frames.size() < 2) {
@@ -128,8 +137,12 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
128137
TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
129138
.set_dispatch<relax::ExternFunc>( //
130139
"", [](relax::ExternFunc n, AccessPath n_p, IRDocsifier d) -> Doc {
131-
// TODO(@junrushao): print more information out of extern function.
132-
return Relax(d, "ExternFunc")->Call({LiteralDoc::Str(n->global_symbol, n_p)});
140+
ffi::Array<ExprDoc> args;
141+
args.push_back(LiteralDoc::Str(n->global_symbol, n_p->Attr("global_symbol")));
142+
if (!HasDefaultExternFuncStructInfo(n)) {
143+
args.push_back(d->AsDoc<ExprDoc>(n->struct_info_, n_p->Attr("struct_info_")));
144+
}
145+
return Relax(d, "ExternFunc")->Call(args);
133146
});
134147

135148
TVM_SCRIPT_REPR(relax::FunctionNode, ReprPrintRelax);

tests/python/relax/test_tvmscript_printer_relax.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,41 @@ def func(a: R.Tensor((10, 10))) -> R.Tensor((10, 10)):
9898
)
9999

100100

101+
def test_extern_func_with_struct_info():
102+
obj = IRModule(
103+
{
104+
"my_ext": relax.ExternFunc(
105+
"my_ext",
106+
relax.FuncStructInfo([], relax.TensorStructInfo(dtype="float32", ndim=2), purity=True),
107+
),
108+
}
109+
)
110+
_assert_print(
111+
obj,
112+
"""
113+
# from tvm.script import ir as I
114+
# from tvm.script import relax as R
115+
116+
@I.ir_module
117+
class Module:
118+
my_ext = R.ExternFunc("my_ext", R.Callable((), R.Tensor(dtype="float32", ndim=2), True))
119+
""",
120+
)
121+
122+
123+
def test_extern_func_with_struct_info_roundtrip():
124+
mod = IRModule(
125+
{
126+
"my_ext": relax.ExternFunc(
127+
"my_ext",
128+
relax.FuncStructInfo([], relax.TensorStructInfo(dtype="float32", ndim=2), purity=True),
129+
),
130+
}
131+
)
132+
roundtrip = tvm.script.from_source(mod.script(verbose_expr=True))
133+
tvm.ir.assert_structural_equal(mod, roundtrip)
134+
135+
101136
def test_nested_function():
102137
@I.ir_module
103138
class NestedFunction:

0 commit comments

Comments
 (0)