11982afb1SRiver Riddle //===- FoldUtils.cpp ---- Fold Utilities ----------------------------------===// 21982afb1SRiver Riddle // 330857107SMehdi Amini // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 456222a06SMehdi Amini // See https://llvm.org/LICENSE.txt for license information. 556222a06SMehdi Amini // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 61982afb1SRiver Riddle // 756222a06SMehdi Amini //===----------------------------------------------------------------------===// 81982afb1SRiver Riddle // 91982afb1SRiver Riddle // This file defines various operation fold utilities. These utilities are 101982afb1SRiver Riddle // intended to be used by passes to unify and simply their logic. 111982afb1SRiver Riddle // 121982afb1SRiver Riddle //===----------------------------------------------------------------------===// 131982afb1SRiver Riddle 141982afb1SRiver Riddle #include "mlir/Transforms/FoldUtils.h" 151982afb1SRiver Riddle 161982afb1SRiver Riddle #include "mlir/IR/Builders.h" 171982afb1SRiver Riddle #include "mlir/IR/Matchers.h" 181982afb1SRiver Riddle #include "mlir/IR/Operation.h" 191982afb1SRiver Riddle 201982afb1SRiver Riddle using namespace mlir; 211982afb1SRiver Riddle 2266ed7d6dSRiver Riddle /// Given an operation, find the parent region that folded constants should be 2366ed7d6dSRiver Riddle /// inserted into. 24b28e3db8SMehdi Amini static Region * 25b28e3db8SMehdi Amini getInsertionRegion(DialectInterfaceCollection<DialectFoldInterface> &interfaces, 2604f2b717SMaheshRavishankar Block *insertionBlock) { 2704f2b717SMaheshRavishankar while (Region *region = insertionBlock->getParent()) { 2866ed7d6dSRiver Riddle // Insert in this region for any of the following scenarios: 2966ed7d6dSRiver Riddle // * The parent is unregistered, or is known to be isolated from above. 3066ed7d6dSRiver Riddle // * The parent is a top-level operation. 311e429540SRiver Riddle auto *parentOp = region->getParentOp(); 32fe7c0d90SRiver Riddle if (parentOp->mightHaveTrait<OpTrait::IsIsolatedFromAbove>() || 33474e3541SRiver Riddle !parentOp->getBlock()) 3466ed7d6dSRiver Riddle return region; 356563b1c4SRiver Riddle 366563b1c4SRiver Riddle // Otherwise, check if this region is a desired insertion region. 376563b1c4SRiver Riddle auto *interface = interfaces.getInterfaceFor(parentOp); 386563b1c4SRiver Riddle if (LLVM_UNLIKELY(interface && interface->shouldMaterializeInto(region))) 396563b1c4SRiver Riddle return region; 406563b1c4SRiver Riddle 4166ed7d6dSRiver Riddle // Traverse up the parent looking for an insertion region. 4204f2b717SMaheshRavishankar insertionBlock = parentOp->getBlock(); 4366ed7d6dSRiver Riddle } 4466ed7d6dSRiver Riddle llvm_unreachable("expected valid insertion region"); 4566ed7d6dSRiver Riddle } 4666ed7d6dSRiver Riddle 4766ed7d6dSRiver Riddle /// A utility function used to materialize a constant for a given attribute and 4866ed7d6dSRiver Riddle /// type. On success, a valid constant value is returned. Otherwise, null is 4966ed7d6dSRiver Riddle /// returned 5066ed7d6dSRiver Riddle static Operation *materializeConstant(Dialect *dialect, OpBuilder &builder, 5166ed7d6dSRiver Riddle Attribute value, Type type, 5266ed7d6dSRiver Riddle Location loc) { 5366ed7d6dSRiver Riddle auto insertPt = builder.getInsertionPoint(); 5466ed7d6dSRiver Riddle (void)insertPt; 5566ed7d6dSRiver Riddle 5666ed7d6dSRiver Riddle // Ask the dialect to materialize a constant operation for this value. 5766ed7d6dSRiver Riddle if (auto *constOp = dialect->materializeConstant(builder, value, type, loc)) { 5866ed7d6dSRiver Riddle assert(insertPt == builder.getInsertionPoint()); 59907403f3SRiver Riddle assert(matchPattern(constOp, m_Constant())); 6066ed7d6dSRiver Riddle return constOp; 6166ed7d6dSRiver Riddle } 62444822d7SSean Silva 6366ed7d6dSRiver Riddle return nullptr; 6466ed7d6dSRiver Riddle } 6566ed7d6dSRiver Riddle 669b4a02c1SRiver Riddle //===----------------------------------------------------------------------===// 679b4a02c1SRiver Riddle // OperationFolder 689b4a02c1SRiver Riddle //===----------------------------------------------------------------------===// 691982afb1SRiver Riddle 709bdfa8dfSMatthias Springer LogicalResult OperationFolder::tryToFold(Operation *op, bool *inPlaceUpdate) { 71cbcb12fdSUday Bondhugula if (inPlaceUpdate) 72cbcb12fdSUday Bondhugula *inPlaceUpdate = false; 73cbcb12fdSUday Bondhugula 74bcacef1aSRiver Riddle // If this is a unique'd constant, return failure as we know that it has 75bcacef1aSRiver Riddle // already been folded. 76af371f9fSRiver Riddle if (isFolderOwnedConstant(op)) { 77af371f9fSRiver Riddle // Check to see if we should rehoist, i.e. if a non-constant operation was 78af371f9fSRiver Riddle // inserted before this one. 79af371f9fSRiver Riddle Block *opBlock = op->getBlock(); 8034a65980SBilly Zhu if (&opBlock->front() != op && !isFolderOwnedConstant(op->getPrevNode())) { 81af371f9fSRiver Riddle op->moveBefore(&opBlock->front()); 8234a65980SBilly Zhu op->setLoc(erasedFoldedLocation); 8334a65980SBilly Zhu } 84bcacef1aSRiver Riddle return failure(); 85af371f9fSRiver Riddle } 861982afb1SRiver Riddle 879b4a02c1SRiver Riddle // Try to fold the operation. 88e62a6956SRiver Riddle SmallVector<Value, 8> results; 899297b9f8SMatthias Springer if (failed(tryToFold(op, results))) 909b4a02c1SRiver Riddle return failure(); 919b4a02c1SRiver Riddle 929b4a02c1SRiver Riddle // Check to see if the operation was just updated in place. 93cbcb12fdSUday Bondhugula if (results.empty()) { 94cbcb12fdSUday Bondhugula if (inPlaceUpdate) 95cbcb12fdSUday Bondhugula *inPlaceUpdate = true; 969297b9f8SMatthias Springer if (auto *rewriteListener = dyn_cast_if_present<RewriterBase::Listener>( 979297b9f8SMatthias Springer rewriter.getListener())) { 989297b9f8SMatthias Springer // Folding API does not notify listeners, so we have to notify manually. 999297b9f8SMatthias Springer rewriteListener->notifyOperationModified(op); 1009297b9f8SMatthias Springer } 1019b4a02c1SRiver Riddle return success(); 102cbcb12fdSUday Bondhugula } 1039b4a02c1SRiver Riddle 1049bdfa8dfSMatthias Springer // Constant folding succeeded. Replace all of the result values and erase the 1059bdfa8dfSMatthias Springer // operation. 1069297b9f8SMatthias Springer notifyRemoval(op); 1079297b9f8SMatthias Springer rewriter.replaceOp(op, results); 1089b4a02c1SRiver Riddle return success(); 1099b4a02c1SRiver Riddle } 1109b4a02c1SRiver Riddle 111af371f9fSRiver Riddle bool OperationFolder::insertKnownConstant(Operation *op, Attribute constValue) { 112af371f9fSRiver Riddle Block *opBlock = op->getBlock(); 113af371f9fSRiver Riddle 114af371f9fSRiver Riddle // If this is a constant we unique'd, we don't need to insert, but we can 115af371f9fSRiver Riddle // check to see if we should rehoist it. 116af371f9fSRiver Riddle if (isFolderOwnedConstant(op)) { 11734a65980SBilly Zhu if (&opBlock->front() != op && !isFolderOwnedConstant(op->getPrevNode())) { 118af371f9fSRiver Riddle op->moveBefore(&opBlock->front()); 11934a65980SBilly Zhu op->setLoc(erasedFoldedLocation); 12034a65980SBilly Zhu } 121af371f9fSRiver Riddle return true; 122af371f9fSRiver Riddle } 123af371f9fSRiver Riddle 124af371f9fSRiver Riddle // Get the constant value of the op if necessary. 125af371f9fSRiver Riddle if (!constValue) { 126af371f9fSRiver Riddle matchPattern(op, m_Constant(&constValue)); 127af371f9fSRiver Riddle assert(constValue && "expected `op` to be a constant"); 128af371f9fSRiver Riddle } else { 129af371f9fSRiver Riddle // Ensure that the provided constant was actually correct. 130af371f9fSRiver Riddle #ifndef NDEBUG 131af371f9fSRiver Riddle Attribute expectedValue; 132af371f9fSRiver Riddle matchPattern(op, m_Constant(&expectedValue)); 133af371f9fSRiver Riddle assert( 134af371f9fSRiver Riddle expectedValue == constValue && 135af371f9fSRiver Riddle "provided constant value was not the expected value of the constant"); 136af371f9fSRiver Riddle #endif 137af371f9fSRiver Riddle } 138af371f9fSRiver Riddle 139af371f9fSRiver Riddle // Check for an existing constant operation for the attribute value. 140af371f9fSRiver Riddle Region *insertRegion = getInsertionRegion(interfaces, opBlock); 141af371f9fSRiver Riddle auto &uniquedConstants = foldScopes[insertRegion]; 142af371f9fSRiver Riddle Operation *&folderConstOp = uniquedConstants[std::make_tuple( 143af371f9fSRiver Riddle op->getDialect(), constValue, *op->result_type_begin())]; 144af371f9fSRiver Riddle 145af371f9fSRiver Riddle // If there is an existing constant, replace `op`. 146af371f9fSRiver Riddle if (folderConstOp) { 1479297b9f8SMatthias Springer notifyRemoval(op); 1489297b9f8SMatthias Springer rewriter.replaceOp(op, folderConstOp->getResults()); 14934a65980SBilly Zhu folderConstOp->setLoc(erasedFoldedLocation); 150af371f9fSRiver Riddle return false; 151af371f9fSRiver Riddle } 152af371f9fSRiver Riddle 153af371f9fSRiver Riddle // Otherwise, we insert `op`. If `op` is in the insertion block and is either 154af371f9fSRiver Riddle // already at the front of the block, or the previous operation is already a 155af371f9fSRiver Riddle // constant we unique'd (i.e. one we inserted), then we don't need to do 156af371f9fSRiver Riddle // anything. Otherwise, we move the constant to the insertion block. 157af371f9fSRiver Riddle Block *insertBlock = &insertRegion->front(); 158af371f9fSRiver Riddle if (opBlock != insertBlock || (&insertBlock->front() != op && 15934a65980SBilly Zhu !isFolderOwnedConstant(op->getPrevNode()))) { 160af371f9fSRiver Riddle op->moveBefore(&insertBlock->front()); 16134a65980SBilly Zhu op->setLoc(erasedFoldedLocation); 16234a65980SBilly Zhu } 163af371f9fSRiver Riddle 164af371f9fSRiver Riddle folderConstOp = op; 165af371f9fSRiver Riddle referencedDialects[op].push_back(op->getDialect()); 166af371f9fSRiver Riddle return true; 167af371f9fSRiver Riddle } 168af371f9fSRiver Riddle 169bcacef1aSRiver Riddle /// Notifies that the given constant `op` should be remove from this 170bcacef1aSRiver Riddle /// OperationFolder's internal bookkeeping. 171bcacef1aSRiver Riddle void OperationFolder::notifyRemoval(Operation *op) { 172bcacef1aSRiver Riddle // Check to see if this operation is uniqued within the folder. 173bcacef1aSRiver Riddle auto it = referencedDialects.find(op); 174bcacef1aSRiver Riddle if (it == referencedDialects.end()) 175bcacef1aSRiver Riddle return; 176bcacef1aSRiver Riddle 177bcacef1aSRiver Riddle // Get the constant value for this operation, this is the value that was used 178bcacef1aSRiver Riddle // to unique the operation internally. 179bcacef1aSRiver Riddle Attribute constValue; 180bcacef1aSRiver Riddle matchPattern(op, m_Constant(&constValue)); 181bcacef1aSRiver Riddle assert(constValue); 182bcacef1aSRiver Riddle 18366ed7d6dSRiver Riddle // Get the constant map that this operation was uniqued in. 18404f2b717SMaheshRavishankar auto &uniquedConstants = 18504f2b717SMaheshRavishankar foldScopes[getInsertionRegion(interfaces, op->getBlock())]; 18666ed7d6dSRiver Riddle 187bcacef1aSRiver Riddle // Erase all of the references to this operation. 1882bdf33ccSRiver Riddle auto type = op->getResult(0).getType(); 189bcacef1aSRiver Riddle for (auto *dialect : it->second) 190bcacef1aSRiver Riddle uniquedConstants.erase(std::make_tuple(dialect, constValue, type)); 191bcacef1aSRiver Riddle referencedDialects.erase(it); 192bcacef1aSRiver Riddle } 193bcacef1aSRiver Riddle 1940ddba0bdSRiver Riddle /// Clear out any constants cached inside of the folder. 1950ddba0bdSRiver Riddle void OperationFolder::clear() { 1960ddba0bdSRiver Riddle foldScopes.clear(); 1970ddba0bdSRiver Riddle referencedDialects.clear(); 1980ddba0bdSRiver Riddle } 1990ddba0bdSRiver Riddle 200152d29ccSRiver Riddle /// Get or create a constant using the given builder. On success this returns 201152d29ccSRiver Riddle /// the constant operation, nullptr otherwise. 2029297b9f8SMatthias Springer Value OperationFolder::getOrCreateConstant(Block *block, Dialect *dialect, 20334a65980SBilly Zhu Attribute value, Type type) { 2049297b9f8SMatthias Springer // Find an insertion point for the constant. 2059297b9f8SMatthias Springer auto *insertRegion = getInsertionRegion(interfaces, block); 206152d29ccSRiver Riddle auto &entry = insertRegion->front(); 207b613a540SMatthias Springer rewriter.setInsertionPointToStart(&entry); 208152d29ccSRiver Riddle 209152d29ccSRiver Riddle // Get the constant map for the insertion region of this operation. 21034a65980SBilly Zhu // Use erased location since the op is being built at the front of block. 211152d29ccSRiver Riddle auto &uniquedConstants = foldScopes[insertRegion]; 21234a65980SBilly Zhu Operation *constOp = tryGetOrCreateConstant(uniquedConstants, dialect, value, 21334a65980SBilly Zhu type, erasedFoldedLocation); 214152d29ccSRiver Riddle return constOp ? constOp->getResult(0) : Value(); 215152d29ccSRiver Riddle } 216152d29ccSRiver Riddle 217af371f9fSRiver Riddle bool OperationFolder::isFolderOwnedConstant(Operation *op) const { 218af371f9fSRiver Riddle return referencedDialects.count(op); 219af371f9fSRiver Riddle } 220af371f9fSRiver Riddle 2219b4a02c1SRiver Riddle /// Tries to perform folding on the given `op`. If successful, populates 2220560f153SRiver Riddle /// `results` with the results of the folding. 2239297b9f8SMatthias Springer LogicalResult OperationFolder::tryToFold(Operation *op, 2249bdfa8dfSMatthias Springer SmallVectorImpl<Value> &results) { 225af371f9fSRiver Riddle SmallVector<OpFoldResult, 8> foldResults; 22621379151SMatthias Springer if (failed(op->fold(foldResults)) || 2279297b9f8SMatthias Springer failed(processFoldResults(op, results, foldResults))) 228dd115e5aSMatthias Springer return failure(); 229af371f9fSRiver Riddle return success(); 230af371f9fSRiver Riddle } 2311982afb1SRiver Riddle 2329bdfa8dfSMatthias Springer LogicalResult 2339297b9f8SMatthias Springer OperationFolder::processFoldResults(Operation *op, 2349bdfa8dfSMatthias Springer SmallVectorImpl<Value> &results, 2359bdfa8dfSMatthias Springer ArrayRef<OpFoldResult> foldResults) { 2361982afb1SRiver Riddle // Check to see if the operation was just updated in place. 2379b4a02c1SRiver Riddle if (foldResults.empty()) 2381982afb1SRiver Riddle return success(); 2399b4a02c1SRiver Riddle assert(foldResults.size() == op->getNumResults()); 2401982afb1SRiver Riddle 24166ed7d6dSRiver Riddle // Create a builder to insert new operations into the entry block of the 24266ed7d6dSRiver Riddle // insertion region. 2439297b9f8SMatthias Springer auto *insertRegion = getInsertionRegion(interfaces, op->getBlock()); 2446563b1c4SRiver Riddle auto &entry = insertRegion->front(); 245b613a540SMatthias Springer rewriter.setInsertionPointToStart(&entry); 24666ed7d6dSRiver Riddle 24766ed7d6dSRiver Riddle // Get the constant map for the insertion region of this operation. 2486563b1c4SRiver Riddle auto &uniquedConstants = foldScopes[insertRegion]; 249bcacef1aSRiver Riddle 2501982afb1SRiver Riddle // Create the result constants and replace the results. 251bcacef1aSRiver Riddle auto *dialect = op->getDialect(); 2521982afb1SRiver Riddle for (unsigned i = 0, e = op->getNumResults(); i != e; ++i) { 2539b4a02c1SRiver Riddle assert(!foldResults[i].isNull() && "expected valid OpFoldResult"); 2541982afb1SRiver Riddle 2551982afb1SRiver Riddle // Check if the result was an SSA value. 25668f58812STres Popp if (auto repl = llvm::dyn_cast_if_present<Value>(foldResults[i])) { 2579b4a02c1SRiver Riddle results.emplace_back(repl); 2581982afb1SRiver Riddle continue; 2591982afb1SRiver Riddle } 2601982afb1SRiver Riddle 261bcacef1aSRiver Riddle // Check to see if there is a canonicalized version of this constant. 26235807bc4SRiver Riddle auto res = op->getResult(i); 263*4f4e2abbSKazu Hirata Attribute attrRepl = cast<Attribute>(foldResults[i]); 26434a65980SBilly Zhu if (auto *constOp = 26534a65980SBilly Zhu tryGetOrCreateConstant(uniquedConstants, dialect, attrRepl, 26634a65980SBilly Zhu res.getType(), erasedFoldedLocation)) { 267e4635e63SRiver Riddle // Ensure that this constant dominates the operation we are replacing it 268e4635e63SRiver Riddle // with. This may not automatically happen if the operation being folded 269e4635e63SRiver Riddle // was inserted before the constant within the insertion block. 27040a89da6SChris Lattner Block *opBlock = op->getBlock(); 27140a89da6SChris Lattner if (opBlock == constOp->getBlock() && &opBlock->front() != constOp) 27240a89da6SChris Lattner constOp->moveBefore(&opBlock->front()); 273e4635e63SRiver Riddle 274bcacef1aSRiver Riddle results.push_back(constOp->getResult(0)); 275bcacef1aSRiver Riddle continue; 2761982afb1SRiver Riddle } 277bcacef1aSRiver Riddle // If materialization fails, cleanup any operations generated for the 278bcacef1aSRiver Riddle // previous results and return failure. 279bcacef1aSRiver Riddle for (Operation &op : llvm::make_early_inc_range( 2809297b9f8SMatthias Springer llvm::make_range(entry.begin(), rewriter.getInsertionPoint()))) { 2819297b9f8SMatthias Springer notifyRemoval(&op); 2829297b9f8SMatthias Springer rewriter.eraseOp(&op); 2839297b9f8SMatthias Springer } 2849bdfa8dfSMatthias Springer 285af371f9fSRiver Riddle results.clear(); 286bcacef1aSRiver Riddle return failure(); 287bcacef1aSRiver Riddle } 288bcacef1aSRiver Riddle 2891982afb1SRiver Riddle return success(); 2901982afb1SRiver Riddle } 2911982afb1SRiver Riddle 292bcacef1aSRiver Riddle /// Try to get or create a new constant entry. On success this returns the 293bcacef1aSRiver Riddle /// constant operation value, nullptr otherwise. 2949297b9f8SMatthias Springer Operation * 2959297b9f8SMatthias Springer OperationFolder::tryGetOrCreateConstant(ConstantMap &uniquedConstants, 2969297b9f8SMatthias Springer Dialect *dialect, Attribute value, 2979297b9f8SMatthias Springer Type type, Location loc) { 298bcacef1aSRiver Riddle // Check if an existing mapping already exists. 299bcacef1aSRiver Riddle auto constKey = std::make_tuple(dialect, value, type); 300648f34a2SChris Lattner Operation *&constOp = uniquedConstants[constKey]; 30134a65980SBilly Zhu if (constOp) { 30234a65980SBilly Zhu if (loc != constOp->getLoc()) 30334a65980SBilly Zhu constOp->setLoc(erasedFoldedLocation); 304648f34a2SChris Lattner return constOp; 30534a65980SBilly Zhu } 3061982afb1SRiver Riddle 307bcacef1aSRiver Riddle // If one doesn't exist, try to materialize one. 3089297b9f8SMatthias Springer if (!(constOp = materializeConstant(dialect, rewriter, value, type, loc))) 309bcacef1aSRiver Riddle return nullptr; 3101982afb1SRiver Riddle 311bcacef1aSRiver Riddle // Check to see if the generated constant is in the expected dialect. 312648f34a2SChris Lattner auto *newDialect = constOp->getDialect(); 313bcacef1aSRiver Riddle if (newDialect == dialect) { 314648f34a2SChris Lattner referencedDialects[constOp].push_back(dialect); 315648f34a2SChris Lattner return constOp; 3161982afb1SRiver Riddle } 3171982afb1SRiver Riddle 318bcacef1aSRiver Riddle // If it isn't, then we also need to make sure that the mapping for the new 319bcacef1aSRiver Riddle // dialect is valid. 320bcacef1aSRiver Riddle auto newKey = std::make_tuple(newDialect, value, type); 3211982afb1SRiver Riddle 322bcacef1aSRiver Riddle // If an existing operation in the new dialect already exists, delete the 323bcacef1aSRiver Riddle // materialized operation in favor of the existing one. 324bcacef1aSRiver Riddle if (auto *existingOp = uniquedConstants.lookup(newKey)) { 3259297b9f8SMatthias Springer notifyRemoval(constOp); 3269297b9f8SMatthias Springer rewriter.eraseOp(constOp); 327bcacef1aSRiver Riddle referencedDialects[existingOp].push_back(dialect); 32834a65980SBilly Zhu if (loc != existingOp->getLoc()) 32934a65980SBilly Zhu existingOp->setLoc(erasedFoldedLocation); 330648f34a2SChris Lattner return constOp = existingOp; 3311982afb1SRiver Riddle } 3321982afb1SRiver Riddle 333bcacef1aSRiver Riddle // Otherwise, update the new dialect to the materialized operation. 334648f34a2SChris Lattner referencedDialects[constOp].assign({dialect, newDialect}); 335648f34a2SChris Lattner auto newIt = uniquedConstants.insert({newKey, constOp}); 336bcacef1aSRiver Riddle return newIt.first->second; 3371982afb1SRiver Riddle } 338