xref: /llvm-project/mlir/lib/Transforms/Utils/FoldUtils.cpp (revision 4f4e2abb1a5ff1225d32410fd02b732d077aa056)
11982afb1SRiver Riddle //===- FoldUtils.cpp ---- Fold Utilities ----------------------------------===//
21982afb1SRiver Riddle //
330857107SMehdi Amini // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
456222a06SMehdi Amini // See https://llvm.org/LICENSE.txt for license information.
556222a06SMehdi Amini // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
61982afb1SRiver Riddle //
756222a06SMehdi Amini //===----------------------------------------------------------------------===//
81982afb1SRiver Riddle //
91982afb1SRiver Riddle // This file defines various operation fold utilities. These utilities are
// intended to be used by passes to unify and simplify their logic.
111982afb1SRiver Riddle //
121982afb1SRiver Riddle //===----------------------------------------------------------------------===//
131982afb1SRiver Riddle 
141982afb1SRiver Riddle #include "mlir/Transforms/FoldUtils.h"
151982afb1SRiver Riddle 
161982afb1SRiver Riddle #include "mlir/IR/Builders.h"
171982afb1SRiver Riddle #include "mlir/IR/Matchers.h"
181982afb1SRiver Riddle #include "mlir/IR/Operation.h"
191982afb1SRiver Riddle 
201982afb1SRiver Riddle using namespace mlir;
211982afb1SRiver Riddle 
2266ed7d6dSRiver Riddle /// Given an operation, find the parent region that folded constants should be
2366ed7d6dSRiver Riddle /// inserted into.
24b28e3db8SMehdi Amini static Region *
25b28e3db8SMehdi Amini getInsertionRegion(DialectInterfaceCollection<DialectFoldInterface> &interfaces,
2604f2b717SMaheshRavishankar                    Block *insertionBlock) {
2704f2b717SMaheshRavishankar   while (Region *region = insertionBlock->getParent()) {
2866ed7d6dSRiver Riddle     // Insert in this region for any of the following scenarios:
2966ed7d6dSRiver Riddle     //  * The parent is unregistered, or is known to be isolated from above.
3066ed7d6dSRiver Riddle     //  * The parent is a top-level operation.
311e429540SRiver Riddle     auto *parentOp = region->getParentOp();
32fe7c0d90SRiver Riddle     if (parentOp->mightHaveTrait<OpTrait::IsIsolatedFromAbove>() ||
33474e3541SRiver Riddle         !parentOp->getBlock())
3466ed7d6dSRiver Riddle       return region;
356563b1c4SRiver Riddle 
366563b1c4SRiver Riddle     // Otherwise, check if this region is a desired insertion region.
376563b1c4SRiver Riddle     auto *interface = interfaces.getInterfaceFor(parentOp);
386563b1c4SRiver Riddle     if (LLVM_UNLIKELY(interface && interface->shouldMaterializeInto(region)))
396563b1c4SRiver Riddle       return region;
406563b1c4SRiver Riddle 
4166ed7d6dSRiver Riddle     // Traverse up the parent looking for an insertion region.
4204f2b717SMaheshRavishankar     insertionBlock = parentOp->getBlock();
4366ed7d6dSRiver Riddle   }
4466ed7d6dSRiver Riddle   llvm_unreachable("expected valid insertion region");
4566ed7d6dSRiver Riddle }
4666ed7d6dSRiver Riddle 
4766ed7d6dSRiver Riddle /// A utility function used to materialize a constant for a given attribute and
4866ed7d6dSRiver Riddle /// type. On success, a valid constant value is returned. Otherwise, null is
4966ed7d6dSRiver Riddle /// returned
5066ed7d6dSRiver Riddle static Operation *materializeConstant(Dialect *dialect, OpBuilder &builder,
5166ed7d6dSRiver Riddle                                       Attribute value, Type type,
5266ed7d6dSRiver Riddle                                       Location loc) {
5366ed7d6dSRiver Riddle   auto insertPt = builder.getInsertionPoint();
5466ed7d6dSRiver Riddle   (void)insertPt;
5566ed7d6dSRiver Riddle 
5666ed7d6dSRiver Riddle   // Ask the dialect to materialize a constant operation for this value.
5766ed7d6dSRiver Riddle   if (auto *constOp = dialect->materializeConstant(builder, value, type, loc)) {
5866ed7d6dSRiver Riddle     assert(insertPt == builder.getInsertionPoint());
59907403f3SRiver Riddle     assert(matchPattern(constOp, m_Constant()));
6066ed7d6dSRiver Riddle     return constOp;
6166ed7d6dSRiver Riddle   }
62444822d7SSean Silva 
6366ed7d6dSRiver Riddle   return nullptr;
6466ed7d6dSRiver Riddle }
6566ed7d6dSRiver Riddle 
669b4a02c1SRiver Riddle //===----------------------------------------------------------------------===//
679b4a02c1SRiver Riddle // OperationFolder
689b4a02c1SRiver Riddle //===----------------------------------------------------------------------===//
691982afb1SRiver Riddle 
709bdfa8dfSMatthias Springer LogicalResult OperationFolder::tryToFold(Operation *op, bool *inPlaceUpdate) {
71cbcb12fdSUday Bondhugula   if (inPlaceUpdate)
72cbcb12fdSUday Bondhugula     *inPlaceUpdate = false;
73cbcb12fdSUday Bondhugula 
74bcacef1aSRiver Riddle   // If this is a unique'd constant, return failure as we know that it has
75bcacef1aSRiver Riddle   // already been folded.
76af371f9fSRiver Riddle   if (isFolderOwnedConstant(op)) {
77af371f9fSRiver Riddle     // Check to see if we should rehoist, i.e. if a non-constant operation was
78af371f9fSRiver Riddle     // inserted before this one.
79af371f9fSRiver Riddle     Block *opBlock = op->getBlock();
8034a65980SBilly Zhu     if (&opBlock->front() != op && !isFolderOwnedConstant(op->getPrevNode())) {
81af371f9fSRiver Riddle       op->moveBefore(&opBlock->front());
8234a65980SBilly Zhu       op->setLoc(erasedFoldedLocation);
8334a65980SBilly Zhu     }
84bcacef1aSRiver Riddle     return failure();
85af371f9fSRiver Riddle   }
861982afb1SRiver Riddle 
879b4a02c1SRiver Riddle   // Try to fold the operation.
88e62a6956SRiver Riddle   SmallVector<Value, 8> results;
899297b9f8SMatthias Springer   if (failed(tryToFold(op, results)))
909b4a02c1SRiver Riddle     return failure();
919b4a02c1SRiver Riddle 
929b4a02c1SRiver Riddle   // Check to see if the operation was just updated in place.
93cbcb12fdSUday Bondhugula   if (results.empty()) {
94cbcb12fdSUday Bondhugula     if (inPlaceUpdate)
95cbcb12fdSUday Bondhugula       *inPlaceUpdate = true;
969297b9f8SMatthias Springer     if (auto *rewriteListener = dyn_cast_if_present<RewriterBase::Listener>(
979297b9f8SMatthias Springer             rewriter.getListener())) {
989297b9f8SMatthias Springer       // Folding API does not notify listeners, so we have to notify manually.
999297b9f8SMatthias Springer       rewriteListener->notifyOperationModified(op);
1009297b9f8SMatthias Springer     }
1019b4a02c1SRiver Riddle     return success();
102cbcb12fdSUday Bondhugula   }
1039b4a02c1SRiver Riddle 
1049bdfa8dfSMatthias Springer   // Constant folding succeeded. Replace all of the result values and erase the
1059bdfa8dfSMatthias Springer   // operation.
1069297b9f8SMatthias Springer   notifyRemoval(op);
1079297b9f8SMatthias Springer   rewriter.replaceOp(op, results);
1089b4a02c1SRiver Riddle   return success();
1099b4a02c1SRiver Riddle }
1109b4a02c1SRiver Riddle 
111af371f9fSRiver Riddle bool OperationFolder::insertKnownConstant(Operation *op, Attribute constValue) {
112af371f9fSRiver Riddle   Block *opBlock = op->getBlock();
113af371f9fSRiver Riddle 
114af371f9fSRiver Riddle   // If this is a constant we unique'd, we don't need to insert, but we can
115af371f9fSRiver Riddle   // check to see if we should rehoist it.
116af371f9fSRiver Riddle   if (isFolderOwnedConstant(op)) {
11734a65980SBilly Zhu     if (&opBlock->front() != op && !isFolderOwnedConstant(op->getPrevNode())) {
118af371f9fSRiver Riddle       op->moveBefore(&opBlock->front());
11934a65980SBilly Zhu       op->setLoc(erasedFoldedLocation);
12034a65980SBilly Zhu     }
121af371f9fSRiver Riddle     return true;
122af371f9fSRiver Riddle   }
123af371f9fSRiver Riddle 
124af371f9fSRiver Riddle   // Get the constant value of the op if necessary.
125af371f9fSRiver Riddle   if (!constValue) {
126af371f9fSRiver Riddle     matchPattern(op, m_Constant(&constValue));
127af371f9fSRiver Riddle     assert(constValue && "expected `op` to be a constant");
128af371f9fSRiver Riddle   } else {
129af371f9fSRiver Riddle     // Ensure that the provided constant was actually correct.
130af371f9fSRiver Riddle #ifndef NDEBUG
131af371f9fSRiver Riddle     Attribute expectedValue;
132af371f9fSRiver Riddle     matchPattern(op, m_Constant(&expectedValue));
133af371f9fSRiver Riddle     assert(
134af371f9fSRiver Riddle         expectedValue == constValue &&
135af371f9fSRiver Riddle         "provided constant value was not the expected value of the constant");
136af371f9fSRiver Riddle #endif
137af371f9fSRiver Riddle   }
138af371f9fSRiver Riddle 
139af371f9fSRiver Riddle   // Check for an existing constant operation for the attribute value.
140af371f9fSRiver Riddle   Region *insertRegion = getInsertionRegion(interfaces, opBlock);
141af371f9fSRiver Riddle   auto &uniquedConstants = foldScopes[insertRegion];
142af371f9fSRiver Riddle   Operation *&folderConstOp = uniquedConstants[std::make_tuple(
143af371f9fSRiver Riddle       op->getDialect(), constValue, *op->result_type_begin())];
144af371f9fSRiver Riddle 
145af371f9fSRiver Riddle   // If there is an existing constant, replace `op`.
146af371f9fSRiver Riddle   if (folderConstOp) {
1479297b9f8SMatthias Springer     notifyRemoval(op);
1489297b9f8SMatthias Springer     rewriter.replaceOp(op, folderConstOp->getResults());
14934a65980SBilly Zhu     folderConstOp->setLoc(erasedFoldedLocation);
150af371f9fSRiver Riddle     return false;
151af371f9fSRiver Riddle   }
152af371f9fSRiver Riddle 
153af371f9fSRiver Riddle   // Otherwise, we insert `op`. If `op` is in the insertion block and is either
154af371f9fSRiver Riddle   // already at the front of the block, or the previous operation is already a
155af371f9fSRiver Riddle   // constant we unique'd (i.e. one we inserted), then we don't need to do
156af371f9fSRiver Riddle   // anything. Otherwise, we move the constant to the insertion block.
157af371f9fSRiver Riddle   Block *insertBlock = &insertRegion->front();
158af371f9fSRiver Riddle   if (opBlock != insertBlock || (&insertBlock->front() != op &&
15934a65980SBilly Zhu                                  !isFolderOwnedConstant(op->getPrevNode()))) {
160af371f9fSRiver Riddle     op->moveBefore(&insertBlock->front());
16134a65980SBilly Zhu     op->setLoc(erasedFoldedLocation);
16234a65980SBilly Zhu   }
163af371f9fSRiver Riddle 
164af371f9fSRiver Riddle   folderConstOp = op;
165af371f9fSRiver Riddle   referencedDialects[op].push_back(op->getDialect());
166af371f9fSRiver Riddle   return true;
167af371f9fSRiver Riddle }
168af371f9fSRiver Riddle 
169bcacef1aSRiver Riddle /// Notifies that the given constant `op` should be remove from this
170bcacef1aSRiver Riddle /// OperationFolder's internal bookkeeping.
171bcacef1aSRiver Riddle void OperationFolder::notifyRemoval(Operation *op) {
172bcacef1aSRiver Riddle   // Check to see if this operation is uniqued within the folder.
173bcacef1aSRiver Riddle   auto it = referencedDialects.find(op);
174bcacef1aSRiver Riddle   if (it == referencedDialects.end())
175bcacef1aSRiver Riddle     return;
176bcacef1aSRiver Riddle 
177bcacef1aSRiver Riddle   // Get the constant value for this operation, this is the value that was used
178bcacef1aSRiver Riddle   // to unique the operation internally.
179bcacef1aSRiver Riddle   Attribute constValue;
180bcacef1aSRiver Riddle   matchPattern(op, m_Constant(&constValue));
181bcacef1aSRiver Riddle   assert(constValue);
182bcacef1aSRiver Riddle 
18366ed7d6dSRiver Riddle   // Get the constant map that this operation was uniqued in.
18404f2b717SMaheshRavishankar   auto &uniquedConstants =
18504f2b717SMaheshRavishankar       foldScopes[getInsertionRegion(interfaces, op->getBlock())];
18666ed7d6dSRiver Riddle 
187bcacef1aSRiver Riddle   // Erase all of the references to this operation.
1882bdf33ccSRiver Riddle   auto type = op->getResult(0).getType();
189bcacef1aSRiver Riddle   for (auto *dialect : it->second)
190bcacef1aSRiver Riddle     uniquedConstants.erase(std::make_tuple(dialect, constValue, type));
191bcacef1aSRiver Riddle   referencedDialects.erase(it);
192bcacef1aSRiver Riddle }
193bcacef1aSRiver Riddle 
1940ddba0bdSRiver Riddle /// Clear out any constants cached inside of the folder.
1950ddba0bdSRiver Riddle void OperationFolder::clear() {
1960ddba0bdSRiver Riddle   foldScopes.clear();
1970ddba0bdSRiver Riddle   referencedDialects.clear();
1980ddba0bdSRiver Riddle }
1990ddba0bdSRiver Riddle 
200152d29ccSRiver Riddle /// Get or create a constant using the given builder. On success this returns
201152d29ccSRiver Riddle /// the constant operation, nullptr otherwise.
2029297b9f8SMatthias Springer Value OperationFolder::getOrCreateConstant(Block *block, Dialect *dialect,
20334a65980SBilly Zhu                                            Attribute value, Type type) {
2049297b9f8SMatthias Springer   // Find an insertion point for the constant.
2059297b9f8SMatthias Springer   auto *insertRegion = getInsertionRegion(interfaces, block);
206152d29ccSRiver Riddle   auto &entry = insertRegion->front();
207b613a540SMatthias Springer   rewriter.setInsertionPointToStart(&entry);
208152d29ccSRiver Riddle 
209152d29ccSRiver Riddle   // Get the constant map for the insertion region of this operation.
21034a65980SBilly Zhu   // Use erased location since the op is being built at the front of block.
211152d29ccSRiver Riddle   auto &uniquedConstants = foldScopes[insertRegion];
21234a65980SBilly Zhu   Operation *constOp = tryGetOrCreateConstant(uniquedConstants, dialect, value,
21334a65980SBilly Zhu                                               type, erasedFoldedLocation);
214152d29ccSRiver Riddle   return constOp ? constOp->getResult(0) : Value();
215152d29ccSRiver Riddle }
216152d29ccSRiver Riddle 
217af371f9fSRiver Riddle bool OperationFolder::isFolderOwnedConstant(Operation *op) const {
218af371f9fSRiver Riddle   return referencedDialects.count(op);
219af371f9fSRiver Riddle }
220af371f9fSRiver Riddle 
2219b4a02c1SRiver Riddle /// Tries to perform folding on the given `op`. If successful, populates
2220560f153SRiver Riddle /// `results` with the results of the folding.
2239297b9f8SMatthias Springer LogicalResult OperationFolder::tryToFold(Operation *op,
2249bdfa8dfSMatthias Springer                                          SmallVectorImpl<Value> &results) {
225af371f9fSRiver Riddle   SmallVector<OpFoldResult, 8> foldResults;
22621379151SMatthias Springer   if (failed(op->fold(foldResults)) ||
2279297b9f8SMatthias Springer       failed(processFoldResults(op, results, foldResults)))
228dd115e5aSMatthias Springer     return failure();
229af371f9fSRiver Riddle   return success();
230af371f9fSRiver Riddle }
2311982afb1SRiver Riddle 
/// Convert the fold results of `op` into replacement values appended to
/// `results`. A Value fold result is used directly. An Attribute fold result
/// is turned into a constant, reused from or added to the uniqued constants of
/// `op`'s insertion region. An empty `foldResults` means `op` was updated in
/// place, and success is returned without touching `results`. If any constant
/// fails to materialize, operations placed at the front of the entry block
/// during this call are erased, `results` is cleared, and failure is returned.
/// Leaves the folder's rewriter insertion point at the start of the insertion
/// region's entry block.
LogicalResult
OperationFolder::processFoldResults(Operation *op,
                                    SmallVectorImpl<Value> &results,
                                    ArrayRef<OpFoldResult> foldResults) {
  // Check to see if the operation was just updated in place.
  if (foldResults.empty())
    return success();
  assert(foldResults.size() == op->getNumResults());

  // Point the rewriter at the start of the entry block of the insertion
  // region; newly materialized constants are inserted there.
  auto *insertRegion = getInsertionRegion(interfaces, op->getBlock());
  auto &entry = insertRegion->front();
  rewriter.setInsertionPointToStart(&entry);

  // Get the constant map for the insertion region of this operation.
  auto &uniquedConstants = foldScopes[insertRegion];

  // Build a replacement value for each result (replacement itself is done by
  // the caller).
  auto *dialect = op->getDialect();
  for (unsigned i = 0, e = op->getNumResults(); i != e; ++i) {
    assert(!foldResults[i].isNull() && "expected valid OpFoldResult");

    // Check if the result was an SSA value.
    if (auto repl = llvm::dyn_cast_if_present<Value>(foldResults[i])) {
      results.emplace_back(repl);
      continue;
    }

    // Check to see if there is a canonicalized version of this constant.
    auto res = op->getResult(i);
    Attribute attrRepl = cast<Attribute>(foldResults[i]);
    if (auto *constOp =
            tryGetOrCreateConstant(uniquedConstants, dialect, attrRepl,
                                   res.getType(), erasedFoldedLocation)) {
      // Ensure that this constant dominates the operation we are replacing it
      // with. This may not automatically happen if the operation being folded
      // was inserted before the constant within the insertion block.
      Block *opBlock = op->getBlock();
      if (opBlock == constOp->getBlock() && &opBlock->front() != constOp)
        constOp->moveBefore(&opBlock->front());

      results.push_back(constOp->getResult(0));
      continue;
    }
    // If materialization fails, cleanup any operations generated for the
    // previous results and return failure.
    // NOTE(review): this erases every op in [entry.begin(), insertion point).
    // Newly materialized constants land in that range, but so can a
    // pre-existing uniqued constant that the dominance fix-up above moved to
    // the front of the block; such a constant may still have uses elsewhere.
    // Conversely, if the reused constant is the op the insertion point refers
    // to, moving it to the front empties the range, and the constants created
    // earlier in this loop are not cleaned up. Confirm whether multi-result
    // folds can reach these cases. (The loop variable `op` shadows the
    // parameter `op`.)
    for (Operation &op : llvm::make_early_inc_range(
             llvm::make_range(entry.begin(), rewriter.getInsertionPoint()))) {
      notifyRemoval(&op);
      rewriter.eraseOp(&op);
    }

    results.clear();
    return failure();
  }

  return success();
}
2911982afb1SRiver Riddle 
/// Try to get or create a new constant entry. On success this returns the
/// constant operation, nullptr otherwise.
///
/// Looks up the key (dialect, value, type) in `uniquedConstants`. If an
/// operation is already mapped, it is returned; its location is replaced with
/// `erasedFoldedLocation` when it differs from `loc`. Otherwise `dialect` is
/// asked to materialize the constant at the rewriter's current insertion point.
/// The dialect may produce an operation from a different dialect. In that case
/// the constant is also keyed under that dialect, and an already-uniqued
/// constant under that key is preferred over the newly built one.
Operation *
OperationFolder::tryGetOrCreateConstant(ConstantMap &uniquedConstants,
                                        Dialect *dialect, Attribute value,
                                        Type type, Location loc) {
  // Check if an existing mapping already exists.
  // `constOp` is a reference to the map slot for `constKey`. operator[] inserts
  // a null entry when the key is absent, so assignments to `constOp` below
  // record the mapping. If materialization fails, the null entry is left in
  // place and is treated as "absent" by later lookups.
  auto constKey = std::make_tuple(dialect, value, type);
  Operation *&constOp = uniquedConstants[constKey];
  if (constOp) {
    if (loc != constOp->getLoc())
      constOp->setLoc(erasedFoldedLocation);
    return constOp;
  }

  // If one doesn't exist, try to materialize one.
  if (!(constOp = materializeConstant(dialect, rewriter, value, type, loc)))
    return nullptr;

  // Check to see if the generated constant is in the expected dialect.
  auto *newDialect = constOp->getDialect();
  if (newDialect == dialect) {
    referencedDialects[constOp].push_back(dialect);
    return constOp;
  }

  // If it isn't, then we also need to make sure that the mapping for the new
  // dialect is valid.
  auto newKey = std::make_tuple(newDialect, value, type);

  // If an existing operation in the new dialect already exists, delete the
  // materialized operation in favor of the existing one.
  if (auto *existingOp = uniquedConstants.lookup(newKey)) {
    // The freshly built op has not been recorded in referencedDialects yet, so
    // this notifyRemoval is expected to return early.
    notifyRemoval(constOp);
    rewriter.eraseOp(constOp);
    referencedDialects[existingOp].push_back(dialect);
    if (loc != existingOp->getLoc())
      existingOp->setLoc(erasedFoldedLocation);
    // Redirect the `constKey` slot to the existing operation.
    return constOp = existingOp;
  }

  // Otherwise, update the new dialect to the materialized operation.
  // The op is now keyed under both `dialect` (via the `constOp` slot) and
  // `newDialect`. The result is read from the returned iterator instead of
  // `constOp`: inserting may grow the map, and (presuming ConstantMap is a
  // DenseMap) that invalidates the `constOp` reference.
  referencedDialects[constOp].assign({dialect, newDialect});
  auto newIt = uniquedConstants.insert({newKey, constOp});
  return newIt.first->second;
}
338