1 //===- DuplicateFunctionElimination.cpp - Duplicate function elimination --===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/Func/IR/FuncOps.h" 10 #include "mlir/Dialect/Func/Transforms/Passes.h" 11 12 namespace mlir { 13 namespace { 14 15 #define GEN_PASS_DEF_DUPLICATEFUNCTIONELIMINATIONPASS 16 #include "mlir/Dialect/Func/Transforms/Passes.h.inc" 17 18 // Define a notion of function equivalence that allows for reuse. Ignore the 19 // symbol name for this purpose. 20 struct DuplicateFuncOpEquivalenceInfo 21 : public llvm::DenseMapInfo<func::FuncOp> { 22 23 static unsigned getHashValue(const func::FuncOp cFunc) { 24 if (!cFunc) { 25 return DenseMapInfo<func::FuncOp>::getHashValue(cFunc); 26 } 27 28 // Aggregate attributes, ignoring the symbol name. 29 llvm::hash_code hash = {}; 30 func::FuncOp func = const_cast<func::FuncOp &>(cFunc); 31 StringAttr symNameAttrName = func.getSymNameAttrName(); 32 for (NamedAttribute namedAttr : cFunc->getAttrs()) { 33 StringAttr attrName = namedAttr.getName(); 34 if (attrName == symNameAttrName) 35 continue; 36 hash = llvm::hash_combine(hash, namedAttr); 37 } 38 39 // Also hash the func body. 40 func.getBody().walk([&](Operation *op) { 41 hash = llvm::hash_combine( 42 hash, OperationEquivalence::computeHash( 43 op, /*hashOperands=*/OperationEquivalence::ignoreHashValue, 44 /*hashResults=*/OperationEquivalence::ignoreHashValue, 45 OperationEquivalence::IgnoreLocations)); 46 }); 47 48 return hash; 49 } 50 51 static bool isEqual(func::FuncOp lhs, func::FuncOp rhs) { 52 if (lhs == rhs) 53 return true; 54 if (lhs == getTombstoneKey() || lhs == getEmptyKey() || 55 rhs == getTombstoneKey() || rhs == getEmptyKey()) 56 return false; 57 58 if (lhs.isDeclaration() || rhs.isDeclaration()) 59 return false; 60 61 // Check discardable attributes equivalence 62 if (lhs->getDiscardableAttrDictionary() != 63 rhs->getDiscardableAttrDictionary()) 64 return false; 65 66 // Check properties equivalence, ignoring the symbol name. 67 // Make a copy, so that we can erase the symbol name and perform the 68 // comparison. 69 auto pLhs = lhs.getProperties(); 70 auto pRhs = rhs.getProperties(); 71 pLhs.sym_name = nullptr; 72 pRhs.sym_name = nullptr; 73 if (pLhs != pRhs) 74 return false; 75 76 // Compare inner workings. 77 return OperationEquivalence::isRegionEquivalentTo( 78 &lhs.getBody(), &rhs.getBody(), OperationEquivalence::IgnoreLocations); 79 } 80 }; 81 82 struct DuplicateFunctionEliminationPass 83 : public impl::DuplicateFunctionEliminationPassBase< 84 DuplicateFunctionEliminationPass> { 85 86 using DuplicateFunctionEliminationPassBase< 87 DuplicateFunctionEliminationPass>::DuplicateFunctionEliminationPassBase; 88 89 void runOnOperation() override { 90 auto module = getOperation(); 91 92 // Find unique representant per equivalent func ops. 93 DenseSet<func::FuncOp, DuplicateFuncOpEquivalenceInfo> uniqueFuncOps; 94 DenseMap<StringAttr, func::FuncOp> getRepresentant; 95 DenseSet<func::FuncOp> toBeErased; 96 module.walk([&](func::FuncOp f) { 97 auto [repr, inserted] = uniqueFuncOps.insert(f); 98 getRepresentant[f.getSymNameAttr()] = *repr; 99 if (!inserted) { 100 toBeErased.insert(f); 101 } 102 }); 103 104 // Update all symbol uses to reference unique func op 105 // representants and erase redundant func ops. 106 SymbolTableCollection symbolTable; 107 SymbolUserMap userMap(symbolTable, module); 108 for (auto it : toBeErased) { 109 StringAttr oldSymbol = it.getSymNameAttr(); 110 StringAttr newSymbol = getRepresentant[oldSymbol].getSymNameAttr(); 111 userMap.replaceAllUsesWith(it, newSymbol); 112 it.erase(); 113 } 114 } 115 }; 116 117 } // namespace 118 119 std::unique_ptr<Pass> mlir::func::createDuplicateFunctionEliminationPass() { 120 return std::make_unique<DuplicateFunctionEliminationPass>(); 121 } 122 123 } // namespace mlir 124