1b12bcf3fSFrederik Gossen //===- DuplicateFunctionElimination.cpp - Duplicate function elimination --===// 2b12bcf3fSFrederik Gossen // 3b12bcf3fSFrederik Gossen // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4b12bcf3fSFrederik Gossen // See https://llvm.org/LICENSE.txt for license information. 5b12bcf3fSFrederik Gossen // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6b12bcf3fSFrederik Gossen // 7b12bcf3fSFrederik Gossen //===----------------------------------------------------------------------===// 8b12bcf3fSFrederik Gossen 9b12bcf3fSFrederik Gossen #include "mlir/Dialect/Func/IR/FuncOps.h" 10b12bcf3fSFrederik Gossen #include "mlir/Dialect/Func/Transforms/Passes.h" 11b12bcf3fSFrederik Gossen 12b12bcf3fSFrederik Gossen namespace mlir { 13b12bcf3fSFrederik Gossen namespace { 14b12bcf3fSFrederik Gossen 15b12bcf3fSFrederik Gossen #define GEN_PASS_DEF_DUPLICATEFUNCTIONELIMINATIONPASS 16b12bcf3fSFrederik Gossen #include "mlir/Dialect/Func/Transforms/Passes.h.inc" 17b12bcf3fSFrederik Gossen 18b12bcf3fSFrederik Gossen // Define a notion of function equivalence that allows for reuse. Ignore the 19b12bcf3fSFrederik Gossen // symbol name for this purpose. 20b12bcf3fSFrederik Gossen struct DuplicateFuncOpEquivalenceInfo 21b12bcf3fSFrederik Gossen : public llvm::DenseMapInfo<func::FuncOp> { 22b12bcf3fSFrederik Gossen 23b12bcf3fSFrederik Gossen static unsigned getHashValue(const func::FuncOp cFunc) { 24b12bcf3fSFrederik Gossen if (!cFunc) { 25b12bcf3fSFrederik Gossen return DenseMapInfo<func::FuncOp>::getHashValue(cFunc); 26b12bcf3fSFrederik Gossen } 27b12bcf3fSFrederik Gossen 28b12bcf3fSFrederik Gossen // Aggregate attributes, ignoring the symbol name. 29b12bcf3fSFrederik Gossen llvm::hash_code hash = {}; 30b12bcf3fSFrederik Gossen func::FuncOp func = const_cast<func::FuncOp &>(cFunc); 31b12bcf3fSFrederik Gossen StringAttr symNameAttrName = func.getSymNameAttrName(); 32b12bcf3fSFrederik Gossen for (NamedAttribute namedAttr : cFunc->getAttrs()) { 33b12bcf3fSFrederik Gossen StringAttr attrName = namedAttr.getName(); 34b12bcf3fSFrederik Gossen if (attrName == symNameAttrName) 35b12bcf3fSFrederik Gossen continue; 36b12bcf3fSFrederik Gossen hash = llvm::hash_combine(hash, namedAttr); 37b12bcf3fSFrederik Gossen } 38b12bcf3fSFrederik Gossen 39b12bcf3fSFrederik Gossen // Also hash the func body. 40b12bcf3fSFrederik Gossen func.getBody().walk([&](Operation *op) { 41b12bcf3fSFrederik Gossen hash = llvm::hash_combine( 42b12bcf3fSFrederik Gossen hash, OperationEquivalence::computeHash( 43b12bcf3fSFrederik Gossen op, /*hashOperands=*/OperationEquivalence::ignoreHashValue, 44b12bcf3fSFrederik Gossen /*hashResults=*/OperationEquivalence::ignoreHashValue, 45b12bcf3fSFrederik Gossen OperationEquivalence::IgnoreLocations)); 46b12bcf3fSFrederik Gossen }); 47b12bcf3fSFrederik Gossen 48b12bcf3fSFrederik Gossen return hash; 49b12bcf3fSFrederik Gossen } 50b12bcf3fSFrederik Gossen 5127b73922SMehdi Amini static bool isEqual(func::FuncOp lhs, func::FuncOp rhs) { 5227b73922SMehdi Amini if (lhs == rhs) 53b12bcf3fSFrederik Gossen return true; 5427b73922SMehdi Amini if (lhs == getTombstoneKey() || lhs == getEmptyKey() || 5527b73922SMehdi Amini rhs == getTombstoneKey() || rhs == getEmptyKey()) 56b12bcf3fSFrederik Gossen return false; 57*2ce655cfSLongsheng Mou 58*2ce655cfSLongsheng Mou if (lhs.isDeclaration() || rhs.isDeclaration()) 59*2ce655cfSLongsheng Mou return false; 60*2ce655cfSLongsheng Mou 6127b73922SMehdi Amini // Check discardable attributes equivalence 6227b73922SMehdi Amini if (lhs->getDiscardableAttrDictionary() != 6327b73922SMehdi Amini rhs->getDiscardableAttrDictionary()) 6427b73922SMehdi Amini return false; 65b12bcf3fSFrederik Gossen 6627b73922SMehdi Amini // Check properties equivalence, ignoring the symbol name. 6727b73922SMehdi Amini // Make a copy, so that we can erase the symbol name and perform the 6827b73922SMehdi Amini // comparison. 6927b73922SMehdi Amini auto pLhs = lhs.getProperties(); 7027b73922SMehdi Amini auto pRhs = rhs.getProperties(); 7127b73922SMehdi Amini pLhs.sym_name = nullptr; 7227b73922SMehdi Amini pRhs.sym_name = nullptr; 7327b73922SMehdi Amini if (pLhs != pRhs) 74b12bcf3fSFrederik Gossen return false; 75b12bcf3fSFrederik Gossen 76b12bcf3fSFrederik Gossen // Compare inner workings. 77b12bcf3fSFrederik Gossen return OperationEquivalence::isRegionEquivalentTo( 78b12bcf3fSFrederik Gossen &lhs.getBody(), &rhs.getBody(), OperationEquivalence::IgnoreLocations); 79b12bcf3fSFrederik Gossen } 80b12bcf3fSFrederik Gossen }; 81b12bcf3fSFrederik Gossen 82b12bcf3fSFrederik Gossen struct DuplicateFunctionEliminationPass 83b12bcf3fSFrederik Gossen : public impl::DuplicateFunctionEliminationPassBase< 84b12bcf3fSFrederik Gossen DuplicateFunctionEliminationPass> { 85b12bcf3fSFrederik Gossen 86b12bcf3fSFrederik Gossen using DuplicateFunctionEliminationPassBase< 87b12bcf3fSFrederik Gossen DuplicateFunctionEliminationPass>::DuplicateFunctionEliminationPassBase; 88b12bcf3fSFrederik Gossen 89b12bcf3fSFrederik Gossen void runOnOperation() override { 90b12bcf3fSFrederik Gossen auto module = getOperation(); 91b12bcf3fSFrederik Gossen 92b12bcf3fSFrederik Gossen // Find unique representant per equivalent func ops. 93b12bcf3fSFrederik Gossen DenseSet<func::FuncOp, DuplicateFuncOpEquivalenceInfo> uniqueFuncOps; 94b12bcf3fSFrederik Gossen DenseMap<StringAttr, func::FuncOp> getRepresentant; 95b12bcf3fSFrederik Gossen DenseSet<func::FuncOp> toBeErased; 96b12bcf3fSFrederik Gossen module.walk([&](func::FuncOp f) { 97b12bcf3fSFrederik Gossen auto [repr, inserted] = uniqueFuncOps.insert(f); 98b12bcf3fSFrederik Gossen getRepresentant[f.getSymNameAttr()] = *repr; 99b12bcf3fSFrederik Gossen if (!inserted) { 100b12bcf3fSFrederik Gossen toBeErased.insert(f); 101b12bcf3fSFrederik Gossen } 102b12bcf3fSFrederik Gossen }); 103b12bcf3fSFrederik Gossen 104*2ce655cfSLongsheng Mou // Update all symbol uses to reference unique func op 105*2ce655cfSLongsheng Mou // representants and erase redundant func ops. 106*2ce655cfSLongsheng Mou SymbolTableCollection symbolTable; 107*2ce655cfSLongsheng Mou SymbolUserMap userMap(symbolTable, module); 108b12bcf3fSFrederik Gossen for (auto it : toBeErased) { 109*2ce655cfSLongsheng Mou StringAttr oldSymbol = it.getSymNameAttr(); 110*2ce655cfSLongsheng Mou StringAttr newSymbol = getRepresentant[oldSymbol].getSymNameAttr(); 111*2ce655cfSLongsheng Mou userMap.replaceAllUsesWith(it, newSymbol); 112b12bcf3fSFrederik Gossen it.erase(); 113b12bcf3fSFrederik Gossen } 114b12bcf3fSFrederik Gossen } 115b12bcf3fSFrederik Gossen }; 116b12bcf3fSFrederik Gossen 117b12bcf3fSFrederik Gossen } // namespace 118b12bcf3fSFrederik Gossen 119b12bcf3fSFrederik Gossen std::unique_ptr<Pass> mlir::func::createDuplicateFunctionEliminationPass() { 120b12bcf3fSFrederik Gossen return std::make_unique<DuplicateFunctionEliminationPass>(); 121b12bcf3fSFrederik Gossen } 122b12bcf3fSFrederik Gossen 123b12bcf3fSFrederik Gossen } // namespace mlir 124