-
Notifications
You must be signed in to change notification settings - Fork 3.9k
[REFACTOR][IR] Inline ReplaceGlobalVars into AttachGlobalSymbol #19625
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
This file was deleted.
This file was deleted.
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -24,15 +24,117 @@ | |
| #include <tvm/ffi/cast.h> | ||
| #include <tvm/ffi/reflection/registry.h> | ||
| #include <tvm/ir/module.h> | ||
| #include <tvm/ir/replace_global_vars.h> | ||
| #include <tvm/relax/expr_functor.h> | ||
| #include <tvm/relax/struct_info.h> | ||
| #include <tvm/relax/transform.h> | ||
| #include <tvm/tirx/function.h> | ||
| #include <tvm/tirx/stmt_functor.h> | ||
|
|
||
| #include <vector> | ||
|
|
||
| namespace tvm { | ||
| namespace relax { | ||
| namespace transform { | ||
|
|
||
| namespace { | ||
|
|
||
| // File-local mutator: replace GlobalVar references inside a relax::Function. | ||
| struct RelaxGvarMutator : ExprMutator { | ||
| ffi::Map<GlobalVar, GlobalVar> replacements; | ||
| explicit RelaxGvarMutator(ffi::Map<GlobalVar, GlobalVar> replacements) | ||
| : replacements(replacements) {} | ||
|
|
||
| using ExprMutator::VisitExpr_; | ||
| Expr VisitExpr_(const GlobalVarNode* node) override { | ||
| auto gvar = ffi::GetRef<GlobalVar>(node); | ||
| return replacements.Get(gvar).value_or(gvar); | ||
| } | ||
| }; | ||
|
|
||
| // File-local mutator: replace GlobalVar references inside a tirx::PrimFunc. | ||
| struct TirxGvarMutator : tirx::StmtExprMutator { | ||
| ffi::Map<GlobalVar, GlobalVar> replacements; | ||
| explicit TirxGvarMutator(ffi::Map<GlobalVar, GlobalVar> replacements) | ||
| : replacements(replacements) {} | ||
|
|
||
| PrimExpr VisitExpr_(const tirx::CallNode* node) override { | ||
| auto call = Downcast<tirx::Call>(tirx::StmtExprMutator::VisitExpr_(node)); | ||
| if (auto old_gvar = call->op.as<GlobalVar>()) { | ||
| if (auto new_gvar = replacements.Get(old_gvar.value())) { | ||
| call.CopyOnWrite()->op = new_gvar.value(); | ||
| } | ||
| } | ||
| return call; | ||
| } | ||
| }; | ||
|
|
||
| // Replace GlobalVar references across all functions in the module. | ||
| // Direct dispatch on function type — no NodeFunctor indirection needed | ||
| // since this file already includes the relax + tirx headers. | ||
| IRModule ReplaceGlobalVarsInModule(IRModule mod, ffi::Map<GlobalVar, GlobalVar> replacements) { | ||
| if (replacements.empty()) { | ||
| return mod; | ||
| } | ||
|
|
||
| std::vector<GlobalVar> to_remove; | ||
| IRModule updates; | ||
|
|
||
| for (const auto& [old_gvar, old_func] : mod->functions) { | ||
| auto new_gvar = replacements.Get(old_gvar).value_or(old_gvar); | ||
| BaseFunc new_func; | ||
|
|
||
| if (auto* prim_func_node = old_func.as<tirx::PrimFuncNode>()) { | ||
| auto func = ffi::GetRef<tirx::PrimFunc>(prim_func_node); | ||
| TirxGvarMutator mutator(replacements); | ||
| auto new_body = mutator(func->body); | ||
| if (!new_body.same_as(func->body)) { | ||
| func.CopyOnWrite()->body = new_body; | ||
| } | ||
| // Update kGlobalSymbol if the function is externally exposed and being renamed. | ||
| if (func->GetAttr<ffi::String>(tvm::attr::kGlobalSymbol)) { | ||
| if (new_gvar->name_hint != old_gvar->name_hint) { | ||
| func = WithAttr(func, tvm::attr::kGlobalSymbol, new_gvar->name_hint); | ||
| } | ||
| } | ||
| new_func = func; | ||
| } else if (auto* relax_func_node = old_func.as<FunctionNode>()) { | ||
| RelaxGvarMutator mutator(replacements); | ||
| auto new_relax_func = | ||
| Downcast<Function>(mutator(Downcast<Function>(ffi::GetRef<Function>(relax_func_node)))); | ||
| // Update kGlobalSymbol if the function is externally exposed and being renamed. | ||
| if (new_relax_func->GetAttr<ffi::String>(tvm::attr::kGlobalSymbol)) { | ||
| if (new_gvar->name_hint != old_gvar->name_hint) { | ||
| new_relax_func = WithAttr(new_relax_func, tvm::attr::kGlobalSymbol, new_gvar->name_hint); | ||
| } | ||
| } | ||
|
Comment on lines
+104
to
+109
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Similarly, since we have direct access to // Update kGlobalSymbol if the function is externally exposed and being renamed.
if (new_relax_func->GetAttr<ffi::String>(tvm::attr::kGlobalSymbol)) {
if (new_gvar->name_hint != old_gvar->name_hint) {
new_relax_func = WithAttr(new_relax_func, tvm::attr::kGlobalSymbol, new_gvar->name_hint);
}
} |
||
| new_func = new_relax_func; | ||
| } else if (old_func.as<ExternFuncNode>()) { | ||
| // ExternFunc: no internal GlobalVar references to update. | ||
| new_func = old_func; | ||
| } else { | ||
| new_func = old_func; | ||
| } | ||
|
|
||
| if (!new_gvar.same_as(old_gvar)) { | ||
| to_remove.push_back(old_gvar); | ||
| } | ||
| if (!old_gvar.same_as(new_gvar) || !old_func.same_as(new_func)) { | ||
| updates->Add(new_gvar, new_func); | ||
| } | ||
| } | ||
|
|
||
| if (to_remove.size() || updates->functions.size()) { | ||
| auto write_ptr = mod.CopyOnWrite(); | ||
| for (const auto& old_gvar : to_remove) { | ||
| write_ptr->Remove(old_gvar); | ||
| } | ||
| write_ptr->Update(updates); | ||
| } | ||
| return mod; | ||
| } | ||
|
|
||
| } // namespace | ||
|
|
||
| Pass AttachGlobalSymbol() { | ||
| auto pass_func = [=](IRModule mod, PassContext pc) { | ||
| ffi::String c_prefix = mod->GetAttr<ffi::String>(tvm::attr::kSystemLibPrefix).value_or(""); | ||
|
|
@@ -74,7 +176,7 @@ Pass AttachGlobalSymbol() { | |
| mod.CopyOnWrite()->Update(updates); | ||
|
|
||
| if (gvar_updates.size()) { | ||
| mod = tvm::transform::ReplaceGlobalVars(mod, gvar_updates); | ||
| mod = ReplaceGlobalVarsInModule(mod, gvar_updates); | ||
| } | ||
| } | ||
| return mod; | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Since we are already iterating over
mod->functionsand have direct access toold_gvarandnew_gvar, we do not need to perform a linear scan overreplacementsto find the matchingGlobalVar. We can directly check ifnew_gvarhas a different name thanold_gvarand update the attribute accordingly. This simplifies the code and improves the lookup complexity from O(N) to O(1).