xref: /llvm-project/mlir/lib/Dialect/Func/Transforms/DuplicateFunctionElimination.cpp (revision 2ce655cf1b029481b88b48b409d7423472856b38)
1 //===- DuplicateFunctionElimination.cpp - Duplicate function elimination --===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/Func/IR/FuncOps.h"
10 #include "mlir/Dialect/Func/Transforms/Passes.h"
11 
12 namespace mlir {
13 namespace {
14 
15 #define GEN_PASS_DEF_DUPLICATEFUNCTIONELIMINATIONPASS
16 #include "mlir/Dialect/Func/Transforms/Passes.h.inc"
17 
18 // Define a notion of function equivalence that allows for reuse. Ignore the
19 // symbol name for this purpose.
20 struct DuplicateFuncOpEquivalenceInfo
21     : public llvm::DenseMapInfo<func::FuncOp> {
22 
23   static unsigned getHashValue(const func::FuncOp cFunc) {
24     if (!cFunc) {
25       return DenseMapInfo<func::FuncOp>::getHashValue(cFunc);
26     }
27 
28     // Aggregate attributes, ignoring the symbol name.
29     llvm::hash_code hash = {};
30     func::FuncOp func = const_cast<func::FuncOp &>(cFunc);
31     StringAttr symNameAttrName = func.getSymNameAttrName();
32     for (NamedAttribute namedAttr : cFunc->getAttrs()) {
33       StringAttr attrName = namedAttr.getName();
34       if (attrName == symNameAttrName)
35         continue;
36       hash = llvm::hash_combine(hash, namedAttr);
37     }
38 
39     // Also hash the func body.
40     func.getBody().walk([&](Operation *op) {
41       hash = llvm::hash_combine(
42           hash, OperationEquivalence::computeHash(
43                     op, /*hashOperands=*/OperationEquivalence::ignoreHashValue,
44                     /*hashResults=*/OperationEquivalence::ignoreHashValue,
45                     OperationEquivalence::IgnoreLocations));
46     });
47 
48     return hash;
49   }
50 
51   static bool isEqual(func::FuncOp lhs, func::FuncOp rhs) {
52     if (lhs == rhs)
53       return true;
54     if (lhs == getTombstoneKey() || lhs == getEmptyKey() ||
55         rhs == getTombstoneKey() || rhs == getEmptyKey())
56       return false;
57 
58     if (lhs.isDeclaration() || rhs.isDeclaration())
59       return false;
60 
61     // Check discardable attributes equivalence
62     if (lhs->getDiscardableAttrDictionary() !=
63         rhs->getDiscardableAttrDictionary())
64       return false;
65 
66     // Check properties equivalence, ignoring the symbol name.
67     // Make a copy, so that we can erase the symbol name and perform the
68     // comparison.
69     auto pLhs = lhs.getProperties();
70     auto pRhs = rhs.getProperties();
71     pLhs.sym_name = nullptr;
72     pRhs.sym_name = nullptr;
73     if (pLhs != pRhs)
74       return false;
75 
76     // Compare inner workings.
77     return OperationEquivalence::isRegionEquivalentTo(
78         &lhs.getBody(), &rhs.getBody(), OperationEquivalence::IgnoreLocations);
79   }
80 };
81 
82 struct DuplicateFunctionEliminationPass
83     : public impl::DuplicateFunctionEliminationPassBase<
84           DuplicateFunctionEliminationPass> {
85 
86   using DuplicateFunctionEliminationPassBase<
87       DuplicateFunctionEliminationPass>::DuplicateFunctionEliminationPassBase;
88 
89   void runOnOperation() override {
90     auto module = getOperation();
91 
92     // Find unique representant per equivalent func ops.
93     DenseSet<func::FuncOp, DuplicateFuncOpEquivalenceInfo> uniqueFuncOps;
94     DenseMap<StringAttr, func::FuncOp> getRepresentant;
95     DenseSet<func::FuncOp> toBeErased;
96     module.walk([&](func::FuncOp f) {
97       auto [repr, inserted] = uniqueFuncOps.insert(f);
98       getRepresentant[f.getSymNameAttr()] = *repr;
99       if (!inserted) {
100         toBeErased.insert(f);
101       }
102     });
103 
104     // Update all symbol uses to reference unique func op
105     // representants and erase redundant func ops.
106     SymbolTableCollection symbolTable;
107     SymbolUserMap userMap(symbolTable, module);
108     for (auto it : toBeErased) {
109       StringAttr oldSymbol = it.getSymNameAttr();
110       StringAttr newSymbol = getRepresentant[oldSymbol].getSymNameAttr();
111       userMap.replaceAllUsesWith(it, newSymbol);
112       it.erase();
113     }
114   }
115 };
116 
117 } // namespace
118 
119 std::unique_ptr<Pass> mlir::func::createDuplicateFunctionEliminationPass() {
120   return std::make_unique<DuplicateFunctionEliminationPass>();
121 }
122 
123 } // namespace mlir
124