Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 0 additions & 57 deletions include/tvm/ir/replace_global_vars.h

This file was deleted.

27 changes: 0 additions & 27 deletions python/tvm/ir/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,33 +195,6 @@ def get_global_vars(self):
"""
return _ffi_api.Module_GetGlobalVars(self)

def replace_global_vars(
self,
replacements: dict[str | _expr.GlobalVar, str | _expr.GlobalVar],
) -> "IRModule":
"""Replace GlobalVar instances within the module

Replace GlobalVars within the IRModule. Since the IRModule
may contain internal references to a GlobalVar, either in TIR
or in Relax, this method should be used whenever replacing or
renaming a GlobalVar.

Parameters
----------
replacements: Dict[Union[str, _expr.GlobalVar], Union[str, _expr.GlobalVar]]

A dictionary where each key is a GlobalVar to be replaced,
and the corresponding value is the GlobalVar with which to
replace it.

Returns
-------
IRModule
The updated module

"""
return _ffi_api.Module_ReplaceGlobalVars(self, replacements)

@staticmethod
def from_expr(expr, functions=None):
"""Construct a module from a standalone expression.
Expand Down
110 changes: 0 additions & 110 deletions src/ir/replace_global_vars.cc

This file was deleted.

106 changes: 104 additions & 2 deletions src/relax/transform/attach_global_symbol.cc
Original file line number Diff line number Diff line change
Expand Up @@ -24,15 +24,117 @@
#include <tvm/ffi/cast.h>
#include <tvm/ffi/reflection/registry.h>
#include <tvm/ir/module.h>
#include <tvm/ir/replace_global_vars.h>
#include <tvm/relax/expr_functor.h>
#include <tvm/relax/struct_info.h>
#include <tvm/relax/transform.h>
#include <tvm/tirx/function.h>
#include <tvm/tirx/stmt_functor.h>

#include <vector>

namespace tvm {
namespace relax {
namespace transform {

namespace {

// File-local mutator: replace GlobalVar references inside a relax::Function.
struct RelaxGvarMutator : ExprMutator {
ffi::Map<GlobalVar, GlobalVar> replacements;
explicit RelaxGvarMutator(ffi::Map<GlobalVar, GlobalVar> replacements)
: replacements(replacements) {}

using ExprMutator::VisitExpr_;
Expr VisitExpr_(const GlobalVarNode* node) override {
auto gvar = ffi::GetRef<GlobalVar>(node);
return replacements.Get(gvar).value_or(gvar);
}
};

// File-local mutator: replace GlobalVar references inside a tirx::PrimFunc.
struct TirxGvarMutator : tirx::StmtExprMutator {
ffi::Map<GlobalVar, GlobalVar> replacements;
explicit TirxGvarMutator(ffi::Map<GlobalVar, GlobalVar> replacements)
: replacements(replacements) {}

PrimExpr VisitExpr_(const tirx::CallNode* node) override {
auto call = Downcast<tirx::Call>(tirx::StmtExprMutator::VisitExpr_(node));
if (auto old_gvar = call->op.as<GlobalVar>()) {
if (auto new_gvar = replacements.Get(old_gvar.value())) {
call.CopyOnWrite()->op = new_gvar.value();
}
}
return call;
}
};

// Replace GlobalVar references across all functions in the module.
// Direct dispatch on function type — no NodeFunctor indirection needed
// since this file already includes the relax + tirx headers.
IRModule ReplaceGlobalVarsInModule(IRModule mod, ffi::Map<GlobalVar, GlobalVar> replacements) {
if (replacements.empty()) {
return mod;
}

std::vector<GlobalVar> to_remove;
IRModule updates;

for (const auto& [old_gvar, old_func] : mod->functions) {
auto new_gvar = replacements.Get(old_gvar).value_or(old_gvar);
BaseFunc new_func;

if (auto* prim_func_node = old_func.as<tirx::PrimFuncNode>()) {
auto func = ffi::GetRef<tirx::PrimFunc>(prim_func_node);
TirxGvarMutator mutator(replacements);
auto new_body = mutator(func->body);
if (!new_body.same_as(func->body)) {
func.CopyOnWrite()->body = new_body;
}
// Update kGlobalSymbol if the function is externally exposed and being renamed.
if (func->GetAttr<ffi::String>(tvm::attr::kGlobalSymbol)) {
if (new_gvar->name_hint != old_gvar->name_hint) {
func = WithAttr(func, tvm::attr::kGlobalSymbol, new_gvar->name_hint);
}
}
Comment on lines +93 to +98
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Since we are already iterating over mod->functions and have direct access to old_gvar and new_gvar, we do not need to perform a linear scan over replacements to find the matching GlobalVar. We can directly check if new_gvar has a different name than old_gvar and update the attribute accordingly. This simplifies the code and improves the lookup complexity from O(N) to O(1).

      // Update kGlobalSymbol if the function is externally exposed and being renamed.
      if (func->GetAttr<ffi::String>(tvm::attr::kGlobalSymbol)) {
        if (new_gvar->name_hint != old_gvar->name_hint) {
          func = WithAttr(func, tvm::attr::kGlobalSymbol, new_gvar->name_hint);
        }
      }

new_func = func;
} else if (auto* relax_func_node = old_func.as<FunctionNode>()) {
RelaxGvarMutator mutator(replacements);
auto new_relax_func =
Downcast<Function>(mutator(Downcast<Function>(ffi::GetRef<Function>(relax_func_node))));
// Update kGlobalSymbol if the function is externally exposed and being renamed.
if (new_relax_func->GetAttr<ffi::String>(tvm::attr::kGlobalSymbol)) {
if (new_gvar->name_hint != old_gvar->name_hint) {
new_relax_func = WithAttr(new_relax_func, tvm::attr::kGlobalSymbol, new_gvar->name_hint);
}
}
Comment on lines +104 to +109
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Similarly, since we have direct access to old_gvar and new_gvar, we can avoid the linear scan over replacements here as well. We can directly check if new_gvar has a different name than old_gvar and update the attribute in O(1) time.

      // Update kGlobalSymbol if the function is externally exposed and being renamed.
      if (new_relax_func->GetAttr<ffi::String>(tvm::attr::kGlobalSymbol)) {
        if (new_gvar->name_hint != old_gvar->name_hint) {
          new_relax_func = WithAttr(new_relax_func, tvm::attr::kGlobalSymbol, new_gvar->name_hint);
        }
      }

new_func = new_relax_func;
} else if (old_func.as<ExternFuncNode>()) {
// ExternFunc: no internal GlobalVar references to update.
new_func = old_func;
} else {
new_func = old_func;
}

if (!new_gvar.same_as(old_gvar)) {
to_remove.push_back(old_gvar);
}
if (!old_gvar.same_as(new_gvar) || !old_func.same_as(new_func)) {
updates->Add(new_gvar, new_func);
}
}

if (to_remove.size() || updates->functions.size()) {
auto write_ptr = mod.CopyOnWrite();
for (const auto& old_gvar : to_remove) {
write_ptr->Remove(old_gvar);
}
write_ptr->Update(updates);
}
return mod;
}

} // namespace

Pass AttachGlobalSymbol() {
auto pass_func = [=](IRModule mod, PassContext pc) {
ffi::String c_prefix = mod->GetAttr<ffi::String>(tvm::attr::kSystemLibPrefix).value_or("");
Expand Down Expand Up @@ -74,7 +176,7 @@ Pass AttachGlobalSymbol() {
mod.CopyOnWrite()->Update(updates);

if (gvar_updates.size()) {
mod = tvm::transform::ReplaceGlobalVars(mod, gvar_updates);
mod = ReplaceGlobalVarsInModule(mod, gvar_updates);
}
}
return mod;
Expand Down
Loading
Loading