xref: /llvm-project/mlir/lib/Dialect/Func/Transforms/DuplicateFunctionElimination.cpp (revision 2ce655cf1b029481b88b48b409d7423472856b38)
1b12bcf3fSFrederik Gossen //===- DuplicateFunctionElimination.cpp - Duplicate function elimination --===//
2b12bcf3fSFrederik Gossen //
3b12bcf3fSFrederik Gossen // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4b12bcf3fSFrederik Gossen // See https://llvm.org/LICENSE.txt for license information.
5b12bcf3fSFrederik Gossen // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6b12bcf3fSFrederik Gossen //
7b12bcf3fSFrederik Gossen //===----------------------------------------------------------------------===//
8b12bcf3fSFrederik Gossen 
9b12bcf3fSFrederik Gossen #include "mlir/Dialect/Func/IR/FuncOps.h"
10b12bcf3fSFrederik Gossen #include "mlir/Dialect/Func/Transforms/Passes.h"
11b12bcf3fSFrederik Gossen 
12b12bcf3fSFrederik Gossen namespace mlir {
13b12bcf3fSFrederik Gossen namespace {
14b12bcf3fSFrederik Gossen 
15b12bcf3fSFrederik Gossen #define GEN_PASS_DEF_DUPLICATEFUNCTIONELIMINATIONPASS
16b12bcf3fSFrederik Gossen #include "mlir/Dialect/Func/Transforms/Passes.h.inc"
17b12bcf3fSFrederik Gossen 
18b12bcf3fSFrederik Gossen // Define a notion of function equivalence that allows for reuse. Ignore the
19b12bcf3fSFrederik Gossen // symbol name for this purpose.
20b12bcf3fSFrederik Gossen struct DuplicateFuncOpEquivalenceInfo
21b12bcf3fSFrederik Gossen     : public llvm::DenseMapInfo<func::FuncOp> {
22b12bcf3fSFrederik Gossen 
23b12bcf3fSFrederik Gossen   static unsigned getHashValue(const func::FuncOp cFunc) {
24b12bcf3fSFrederik Gossen     if (!cFunc) {
25b12bcf3fSFrederik Gossen       return DenseMapInfo<func::FuncOp>::getHashValue(cFunc);
26b12bcf3fSFrederik Gossen     }
27b12bcf3fSFrederik Gossen 
28b12bcf3fSFrederik Gossen     // Aggregate attributes, ignoring the symbol name.
29b12bcf3fSFrederik Gossen     llvm::hash_code hash = {};
30b12bcf3fSFrederik Gossen     func::FuncOp func = const_cast<func::FuncOp &>(cFunc);
31b12bcf3fSFrederik Gossen     StringAttr symNameAttrName = func.getSymNameAttrName();
32b12bcf3fSFrederik Gossen     for (NamedAttribute namedAttr : cFunc->getAttrs()) {
33b12bcf3fSFrederik Gossen       StringAttr attrName = namedAttr.getName();
34b12bcf3fSFrederik Gossen       if (attrName == symNameAttrName)
35b12bcf3fSFrederik Gossen         continue;
36b12bcf3fSFrederik Gossen       hash = llvm::hash_combine(hash, namedAttr);
37b12bcf3fSFrederik Gossen     }
38b12bcf3fSFrederik Gossen 
39b12bcf3fSFrederik Gossen     // Also hash the func body.
40b12bcf3fSFrederik Gossen     func.getBody().walk([&](Operation *op) {
41b12bcf3fSFrederik Gossen       hash = llvm::hash_combine(
42b12bcf3fSFrederik Gossen           hash, OperationEquivalence::computeHash(
43b12bcf3fSFrederik Gossen                     op, /*hashOperands=*/OperationEquivalence::ignoreHashValue,
44b12bcf3fSFrederik Gossen                     /*hashResults=*/OperationEquivalence::ignoreHashValue,
45b12bcf3fSFrederik Gossen                     OperationEquivalence::IgnoreLocations));
46b12bcf3fSFrederik Gossen     });
47b12bcf3fSFrederik Gossen 
48b12bcf3fSFrederik Gossen     return hash;
49b12bcf3fSFrederik Gossen   }
50b12bcf3fSFrederik Gossen 
5127b73922SMehdi Amini   static bool isEqual(func::FuncOp lhs, func::FuncOp rhs) {
5227b73922SMehdi Amini     if (lhs == rhs)
53b12bcf3fSFrederik Gossen       return true;
5427b73922SMehdi Amini     if (lhs == getTombstoneKey() || lhs == getEmptyKey() ||
5527b73922SMehdi Amini         rhs == getTombstoneKey() || rhs == getEmptyKey())
56b12bcf3fSFrederik Gossen       return false;
57*2ce655cfSLongsheng Mou 
58*2ce655cfSLongsheng Mou     if (lhs.isDeclaration() || rhs.isDeclaration())
59*2ce655cfSLongsheng Mou       return false;
60*2ce655cfSLongsheng Mou 
6127b73922SMehdi Amini     // Check discardable attributes equivalence
6227b73922SMehdi Amini     if (lhs->getDiscardableAttrDictionary() !=
6327b73922SMehdi Amini         rhs->getDiscardableAttrDictionary())
6427b73922SMehdi Amini       return false;
65b12bcf3fSFrederik Gossen 
6627b73922SMehdi Amini     // Check properties equivalence, ignoring the symbol name.
6727b73922SMehdi Amini     // Make a copy, so that we can erase the symbol name and perform the
6827b73922SMehdi Amini     // comparison.
6927b73922SMehdi Amini     auto pLhs = lhs.getProperties();
7027b73922SMehdi Amini     auto pRhs = rhs.getProperties();
7127b73922SMehdi Amini     pLhs.sym_name = nullptr;
7227b73922SMehdi Amini     pRhs.sym_name = nullptr;
7327b73922SMehdi Amini     if (pLhs != pRhs)
74b12bcf3fSFrederik Gossen       return false;
75b12bcf3fSFrederik Gossen 
76b12bcf3fSFrederik Gossen     // Compare inner workings.
77b12bcf3fSFrederik Gossen     return OperationEquivalence::isRegionEquivalentTo(
78b12bcf3fSFrederik Gossen         &lhs.getBody(), &rhs.getBody(), OperationEquivalence::IgnoreLocations);
79b12bcf3fSFrederik Gossen   }
80b12bcf3fSFrederik Gossen };
81b12bcf3fSFrederik Gossen 
82b12bcf3fSFrederik Gossen struct DuplicateFunctionEliminationPass
83b12bcf3fSFrederik Gossen     : public impl::DuplicateFunctionEliminationPassBase<
84b12bcf3fSFrederik Gossen           DuplicateFunctionEliminationPass> {
85b12bcf3fSFrederik Gossen 
86b12bcf3fSFrederik Gossen   using DuplicateFunctionEliminationPassBase<
87b12bcf3fSFrederik Gossen       DuplicateFunctionEliminationPass>::DuplicateFunctionEliminationPassBase;
88b12bcf3fSFrederik Gossen 
89b12bcf3fSFrederik Gossen   void runOnOperation() override {
90b12bcf3fSFrederik Gossen     auto module = getOperation();
91b12bcf3fSFrederik Gossen 
92b12bcf3fSFrederik Gossen     // Find unique representant per equivalent func ops.
93b12bcf3fSFrederik Gossen     DenseSet<func::FuncOp, DuplicateFuncOpEquivalenceInfo> uniqueFuncOps;
94b12bcf3fSFrederik Gossen     DenseMap<StringAttr, func::FuncOp> getRepresentant;
95b12bcf3fSFrederik Gossen     DenseSet<func::FuncOp> toBeErased;
96b12bcf3fSFrederik Gossen     module.walk([&](func::FuncOp f) {
97b12bcf3fSFrederik Gossen       auto [repr, inserted] = uniqueFuncOps.insert(f);
98b12bcf3fSFrederik Gossen       getRepresentant[f.getSymNameAttr()] = *repr;
99b12bcf3fSFrederik Gossen       if (!inserted) {
100b12bcf3fSFrederik Gossen         toBeErased.insert(f);
101b12bcf3fSFrederik Gossen       }
102b12bcf3fSFrederik Gossen     });
103b12bcf3fSFrederik Gossen 
104*2ce655cfSLongsheng Mou     // Update all symbol uses to reference unique func op
105*2ce655cfSLongsheng Mou     // representants and erase redundant func ops.
106*2ce655cfSLongsheng Mou     SymbolTableCollection symbolTable;
107*2ce655cfSLongsheng Mou     SymbolUserMap userMap(symbolTable, module);
108b12bcf3fSFrederik Gossen     for (auto it : toBeErased) {
109*2ce655cfSLongsheng Mou       StringAttr oldSymbol = it.getSymNameAttr();
110*2ce655cfSLongsheng Mou       StringAttr newSymbol = getRepresentant[oldSymbol].getSymNameAttr();
111*2ce655cfSLongsheng Mou       userMap.replaceAllUsesWith(it, newSymbol);
112b12bcf3fSFrederik Gossen       it.erase();
113b12bcf3fSFrederik Gossen     }
114b12bcf3fSFrederik Gossen   }
115b12bcf3fSFrederik Gossen };
116b12bcf3fSFrederik Gossen 
117b12bcf3fSFrederik Gossen } // namespace
118b12bcf3fSFrederik Gossen 
119b12bcf3fSFrederik Gossen std::unique_ptr<Pass> mlir::func::createDuplicateFunctionEliminationPass() {
120b12bcf3fSFrederik Gossen   return std::make_unique<DuplicateFunctionEliminationPass>();
121b12bcf3fSFrederik Gossen }
122b12bcf3fSFrederik Gossen 
123b12bcf3fSFrederik Gossen } // namespace mlir
124