xref: /llvm-project/mlir/lib/Transforms/Utils/FoldUtils.cpp (revision 68f58812e3e99e31d77c0c23b6298489444dc0be)
1 //===- FoldUtils.cpp ---- Fold Utilities ----------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines various operation fold utilities. These utilities are
10 // intended to be used by passes to unify and simplify their logic.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "mlir/Transforms/FoldUtils.h"
15 
16 #include "mlir/IR/Builders.h"
17 #include "mlir/IR/Matchers.h"
18 #include "mlir/IR/Operation.h"
19 
20 using namespace mlir;
21 
22 /// Given an operation, find the parent region that folded constants should be
23 /// inserted into.
24 static Region *
25 getInsertionRegion(DialectInterfaceCollection<DialectFoldInterface> &interfaces,
26                    Block *insertionBlock) {
27   while (Region *region = insertionBlock->getParent()) {
28     // Insert in this region for any of the following scenarios:
29     //  * The parent is unregistered, or is known to be isolated from above.
30     //  * The parent is a top-level operation.
31     auto *parentOp = region->getParentOp();
32     if (parentOp->mightHaveTrait<OpTrait::IsIsolatedFromAbove>() ||
33         !parentOp->getBlock())
34       return region;
35 
36     // Otherwise, check if this region is a desired insertion region.
37     auto *interface = interfaces.getInterfaceFor(parentOp);
38     if (LLVM_UNLIKELY(interface && interface->shouldMaterializeInto(region)))
39       return region;
40 
41     // Traverse up the parent looking for an insertion region.
42     insertionBlock = parentOp->getBlock();
43   }
44   llvm_unreachable("expected valid insertion region");
45 }
46 
47 /// A utility function used to materialize a constant for a given attribute and
48 /// type. On success, a valid constant value is returned. Otherwise, null is
49 /// returned
50 static Operation *materializeConstant(Dialect *dialect, OpBuilder &builder,
51                                       Attribute value, Type type,
52                                       Location loc) {
53   auto insertPt = builder.getInsertionPoint();
54   (void)insertPt;
55 
56   // Ask the dialect to materialize a constant operation for this value.
57   if (auto *constOp = dialect->materializeConstant(builder, value, type, loc)) {
58     assert(insertPt == builder.getInsertionPoint());
59     assert(matchPattern(constOp, m_Constant()));
60     return constOp;
61   }
62 
63   return nullptr;
64 }
65 
66 //===----------------------------------------------------------------------===//
67 // OperationFolder
68 //===----------------------------------------------------------------------===//
69 
70 LogicalResult OperationFolder::tryToFold(Operation *op, bool *inPlaceUpdate) {
71   if (inPlaceUpdate)
72     *inPlaceUpdate = false;
73 
74   // If this is a unique'd constant, return failure as we know that it has
75   // already been folded.
76   if (isFolderOwnedConstant(op)) {
77     // Check to see if we should rehoist, i.e. if a non-constant operation was
78     // inserted before this one.
79     Block *opBlock = op->getBlock();
80     if (&opBlock->front() != op && !isFolderOwnedConstant(op->getPrevNode()))
81       op->moveBefore(&opBlock->front());
82     return failure();
83   }
84 
85   // Try to fold the operation.
86   SmallVector<Value, 8> results;
87   if (failed(tryToFold(op, results)))
88     return failure();
89 
90   // Check to see if the operation was just updated in place.
91   if (results.empty()) {
92     if (inPlaceUpdate)
93       *inPlaceUpdate = true;
94     if (auto *rewriteListener = dyn_cast_if_present<RewriterBase::Listener>(
95             rewriter.getListener())) {
96       // Folding API does not notify listeners, so we have to notify manually.
97       rewriteListener->notifyOperationModified(op);
98     }
99     return success();
100   }
101 
102   // Constant folding succeeded. Replace all of the result values and erase the
103   // operation.
104   notifyRemoval(op);
105   rewriter.replaceOp(op, results);
106   return success();
107 }
108 
109 bool OperationFolder::insertKnownConstant(Operation *op, Attribute constValue) {
110   Block *opBlock = op->getBlock();
111 
112   // If this is a constant we unique'd, we don't need to insert, but we can
113   // check to see if we should rehoist it.
114   if (isFolderOwnedConstant(op)) {
115     if (&opBlock->front() != op && !isFolderOwnedConstant(op->getPrevNode()))
116       op->moveBefore(&opBlock->front());
117     return true;
118   }
119 
120   // Get the constant value of the op if necessary.
121   if (!constValue) {
122     matchPattern(op, m_Constant(&constValue));
123     assert(constValue && "expected `op` to be a constant");
124   } else {
125     // Ensure that the provided constant was actually correct.
126 #ifndef NDEBUG
127     Attribute expectedValue;
128     matchPattern(op, m_Constant(&expectedValue));
129     assert(
130         expectedValue == constValue &&
131         "provided constant value was not the expected value of the constant");
132 #endif
133   }
134 
135   // Check for an existing constant operation for the attribute value.
136   Region *insertRegion = getInsertionRegion(interfaces, opBlock);
137   auto &uniquedConstants = foldScopes[insertRegion];
138   Operation *&folderConstOp = uniquedConstants[std::make_tuple(
139       op->getDialect(), constValue, *op->result_type_begin())];
140 
141   // If there is an existing constant, replace `op`.
142   if (folderConstOp) {
143     notifyRemoval(op);
144     rewriter.replaceOp(op, folderConstOp->getResults());
145     return false;
146   }
147 
148   // Otherwise, we insert `op`. If `op` is in the insertion block and is either
149   // already at the front of the block, or the previous operation is already a
150   // constant we unique'd (i.e. one we inserted), then we don't need to do
151   // anything. Otherwise, we move the constant to the insertion block.
152   Block *insertBlock = &insertRegion->front();
153   if (opBlock != insertBlock || (&insertBlock->front() != op &&
154                                  !isFolderOwnedConstant(op->getPrevNode())))
155     op->moveBefore(&insertBlock->front());
156 
157   folderConstOp = op;
158   referencedDialects[op].push_back(op->getDialect());
159   return true;
160 }
161 
162 /// Notifies that the given constant `op` should be remove from this
163 /// OperationFolder's internal bookkeeping.
164 void OperationFolder::notifyRemoval(Operation *op) {
165   // Check to see if this operation is uniqued within the folder.
166   auto it = referencedDialects.find(op);
167   if (it == referencedDialects.end())
168     return;
169 
170   // Get the constant value for this operation, this is the value that was used
171   // to unique the operation internally.
172   Attribute constValue;
173   matchPattern(op, m_Constant(&constValue));
174   assert(constValue);
175 
176   // Get the constant map that this operation was uniqued in.
177   auto &uniquedConstants =
178       foldScopes[getInsertionRegion(interfaces, op->getBlock())];
179 
180   // Erase all of the references to this operation.
181   auto type = op->getResult(0).getType();
182   for (auto *dialect : it->second)
183     uniquedConstants.erase(std::make_tuple(dialect, constValue, type));
184   referencedDialects.erase(it);
185 }
186 
187 /// Clear out any constants cached inside of the folder.
188 void OperationFolder::clear() {
189   foldScopes.clear();
190   referencedDialects.clear();
191 }
192 
193 /// Get or create a constant using the given builder. On success this returns
194 /// the constant operation, nullptr otherwise.
195 Value OperationFolder::getOrCreateConstant(Block *block, Dialect *dialect,
196                                            Attribute value, Type type,
197                                            Location loc) {
198   // Find an insertion point for the constant.
199   auto *insertRegion = getInsertionRegion(interfaces, block);
200   auto &entry = insertRegion->front();
201   rewriter.setInsertionPoint(&entry, entry.begin());
202 
203   // Get the constant map for the insertion region of this operation.
204   auto &uniquedConstants = foldScopes[insertRegion];
205   Operation *constOp =
206       tryGetOrCreateConstant(uniquedConstants, dialect, value, type, loc);
207   return constOp ? constOp->getResult(0) : Value();
208 }
209 
210 bool OperationFolder::isFolderOwnedConstant(Operation *op) const {
211   return referencedDialects.count(op);
212 }
213 
214 /// Tries to perform folding on the given `op`. If successful, populates
215 /// `results` with the results of the folding.
216 LogicalResult OperationFolder::tryToFold(Operation *op,
217                                          SmallVectorImpl<Value> &results) {
218   SmallVector<Attribute, 8> operandConstants;
219 
220   // If this is a commutative operation, move constants to be trailing operands.
221   bool updatedOpOperands = false;
222   if (op->getNumOperands() >= 2 && op->hasTrait<OpTrait::IsCommutative>()) {
223     auto isNonConstant = [&](OpOperand &o) {
224       return !matchPattern(o.get(), m_Constant());
225     };
226     auto *firstConstantIt =
227         llvm::find_if_not(op->getOpOperands(), isNonConstant);
228     auto *newConstantIt = std::stable_partition(
229         firstConstantIt, op->getOpOperands().end(), isNonConstant);
230 
231     // Remember if we actually moved anything.
232     updatedOpOperands = firstConstantIt != newConstantIt;
233   }
234 
235   // Check to see if any operands to the operation is constant and whether
236   // the operation knows how to constant fold itself.
237   operandConstants.assign(op->getNumOperands(), Attribute());
238   for (unsigned i = 0, e = op->getNumOperands(); i != e; ++i)
239     matchPattern(op->getOperand(i), m_Constant(&operandConstants[i]));
240 
241   // Attempt to constant fold the operation. If we failed, check to see if we at
242   // least updated the operands of the operation. We treat this as an in-place
243   // fold.
244   SmallVector<OpFoldResult, 8> foldResults;
245   if (failed(op->fold(operandConstants, foldResults)) ||
246       failed(processFoldResults(op, results, foldResults)))
247     return success(updatedOpOperands);
248   return success();
249 }
250 
251 LogicalResult
252 OperationFolder::processFoldResults(Operation *op,
253                                     SmallVectorImpl<Value> &results,
254                                     ArrayRef<OpFoldResult> foldResults) {
255   // Check to see if the operation was just updated in place.
256   if (foldResults.empty())
257     return success();
258   assert(foldResults.size() == op->getNumResults());
259 
260   // Create a builder to insert new operations into the entry block of the
261   // insertion region.
262   auto *insertRegion = getInsertionRegion(interfaces, op->getBlock());
263   auto &entry = insertRegion->front();
264   rewriter.setInsertionPoint(&entry, entry.begin());
265 
266   // Get the constant map for the insertion region of this operation.
267   auto &uniquedConstants = foldScopes[insertRegion];
268 
269   // Create the result constants and replace the results.
270   auto *dialect = op->getDialect();
271   for (unsigned i = 0, e = op->getNumResults(); i != e; ++i) {
272     assert(!foldResults[i].isNull() && "expected valid OpFoldResult");
273 
274     // Check if the result was an SSA value.
275     if (auto repl = llvm::dyn_cast_if_present<Value>(foldResults[i])) {
276       if (repl.getType() != op->getResult(i).getType()) {
277         results.clear();
278         return failure();
279       }
280       results.emplace_back(repl);
281       continue;
282     }
283 
284     // Check to see if there is a canonicalized version of this constant.
285     auto res = op->getResult(i);
286     Attribute attrRepl = foldResults[i].get<Attribute>();
287     if (auto *constOp = tryGetOrCreateConstant(
288             uniquedConstants, dialect, attrRepl, res.getType(), op->getLoc())) {
289       // Ensure that this constant dominates the operation we are replacing it
290       // with. This may not automatically happen if the operation being folded
291       // was inserted before the constant within the insertion block.
292       Block *opBlock = op->getBlock();
293       if (opBlock == constOp->getBlock() && &opBlock->front() != constOp)
294         constOp->moveBefore(&opBlock->front());
295 
296       results.push_back(constOp->getResult(0));
297       continue;
298     }
299     // If materialization fails, cleanup any operations generated for the
300     // previous results and return failure.
301     for (Operation &op : llvm::make_early_inc_range(
302              llvm::make_range(entry.begin(), rewriter.getInsertionPoint()))) {
303       notifyRemoval(&op);
304       rewriter.eraseOp(&op);
305     }
306 
307     results.clear();
308     return failure();
309   }
310 
311   return success();
312 }
313 
314 /// Try to get or create a new constant entry. On success this returns the
315 /// constant operation value, nullptr otherwise.
316 Operation *
317 OperationFolder::tryGetOrCreateConstant(ConstantMap &uniquedConstants,
318                                         Dialect *dialect, Attribute value,
319                                         Type type, Location loc) {
320   // Check if an existing mapping already exists.
321   auto constKey = std::make_tuple(dialect, value, type);
322   Operation *&constOp = uniquedConstants[constKey];
323   if (constOp)
324     return constOp;
325 
326   // If one doesn't exist, try to materialize one.
327   if (!(constOp = materializeConstant(dialect, rewriter, value, type, loc)))
328     return nullptr;
329 
330   // Check to see if the generated constant is in the expected dialect.
331   auto *newDialect = constOp->getDialect();
332   if (newDialect == dialect) {
333     referencedDialects[constOp].push_back(dialect);
334     return constOp;
335   }
336 
337   // If it isn't, then we also need to make sure that the mapping for the new
338   // dialect is valid.
339   auto newKey = std::make_tuple(newDialect, value, type);
340 
341   // If an existing operation in the new dialect already exists, delete the
342   // materialized operation in favor of the existing one.
343   if (auto *existingOp = uniquedConstants.lookup(newKey)) {
344     notifyRemoval(constOp);
345     rewriter.eraseOp(constOp);
346     referencedDialects[existingOp].push_back(dialect);
347     return constOp = existingOp;
348   }
349 
350   // Otherwise, update the new dialect to the materialized operation.
351   referencedDialects[constOp].assign({dialect, newDialect});
352   auto newIt = uniquedConstants.insert({newKey, constOp});
353   return newIt.first->second;
354 }
355