Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 16 additions & 2 deletions src/relax/ir/block_builder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -429,14 +429,28 @@ class BlockBuilderImpl : public BlockBuilderNode {
}
};

/*!
* \brief Structural equality that DOES compare tensor (constant) data content. The hash above
* intentionally ignores tensor content for speed, but the equality must stay exact: otherwise two
* grouped functions that differ only in their bound constants (e.g. two conv layers with
* different weights) would be incorrectly treated as duplicates and merged.
*/
class StructuralEqualConsiderNDarray {
public:
bool operator()(const ffi::ObjectRef& lhs, const ffi::ObjectRef& rhs) const {
return ffi::StructuralEqual::Equal(lhs, rhs, /*map_free_vars=*/false,
/*skip_tensor_content=*/false);
}
};

/*!
* \brief A hashmap to store the mapping of Relax functions and TIR PrimFuncs
* in context_mod to their GlobalVar to avoid generating duplicated functions.
* We use a custom hash to avoid hashing constants that may be bound to each BaseFunc.
*/
std::unique_ptr<std::unordered_map<
BaseFunc, std::unordered_set<GlobalVar, ffi::ObjectPtrHash, ffi::ObjectPtrEqual>,
StructuralHashIgnoreNDarray, ffi::StructuralEqual>>
StructuralHashIgnoreNDarray, StructuralEqualConsiderNDarray>>
ctx_func_dedup_map_ = nullptr;

/*!
Expand All @@ -446,7 +460,7 @@ class BlockBuilderImpl : public BlockBuilderNode {
if (ctx_func_dedup_map_ != nullptr) return;
ctx_func_dedup_map_ = std::make_unique<std::unordered_map<
BaseFunc, std::unordered_set<GlobalVar, ffi::ObjectPtrHash, ffi::ObjectPtrEqual>,
StructuralHashIgnoreNDarray, ffi::StructuralEqual>>();
StructuralHashIgnoreNDarray, StructuralEqualConsiderNDarray>>();
for (const auto& kv : context_mod_->functions) {
const GlobalVar gv = kv.first;
const BaseFunc func = kv.second;
Expand Down
Loading