xref: /llvm-project/mlir/lib/Transforms/Utils/FoldUtils.cpp (revision a54f4eae0e1d0ef5adccdcf9f6c2b518dc1101aa)
1 //===- FoldUtils.cpp ---- Fold Utilities ----------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines various operation fold utilities. These utilities are
// intended to be used by passes to unify and simplify their logic.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "mlir/Transforms/FoldUtils.h"
15 
16 #include "mlir/IR/Builders.h"
17 #include "mlir/IR/Matchers.h"
18 #include "mlir/IR/Operation.h"
19 
20 using namespace mlir;
21 
22 /// Given an operation, find the parent region that folded constants should be
23 /// inserted into.
24 static Region *
25 getInsertionRegion(DialectInterfaceCollection<DialectFoldInterface> &interfaces,
26                    Block *insertionBlock) {
27   while (Region *region = insertionBlock->getParent()) {
28     // Insert in this region for any of the following scenarios:
29     //  * The parent is unregistered, or is known to be isolated from above.
30     //  * The parent is a top-level operation.
31     auto *parentOp = region->getParentOp();
32     if (parentOp->mightHaveTrait<OpTrait::IsIsolatedFromAbove>() ||
33         !parentOp->getBlock())
34       return region;
35 
36     // Otherwise, check if this region is a desired insertion region.
37     auto *interface = interfaces.getInterfaceFor(parentOp);
38     if (LLVM_UNLIKELY(interface && interface->shouldMaterializeInto(region)))
39       return region;
40 
41     // Traverse up the parent looking for an insertion region.
42     insertionBlock = parentOp->getBlock();
43   }
44   llvm_unreachable("expected valid insertion region");
45 }
46 
47 /// A utility function used to materialize a constant for a given attribute and
48 /// type. On success, a valid constant value is returned. Otherwise, null is
49 /// returned
50 static Operation *materializeConstant(Dialect *dialect, OpBuilder &builder,
51                                       Attribute value, Type type,
52                                       Location loc) {
53   auto insertPt = builder.getInsertionPoint();
54   (void)insertPt;
55 
56   // Ask the dialect to materialize a constant operation for this value.
57   if (auto *constOp = dialect->materializeConstant(builder, value, type, loc)) {
58     assert(insertPt == builder.getInsertionPoint());
59     assert(matchPattern(constOp, m_Constant()));
60     return constOp;
61   }
62 
63   return nullptr;
64 }
65 
66 //===----------------------------------------------------------------------===//
67 // OperationFolder
68 //===----------------------------------------------------------------------===//
69 
70 LogicalResult OperationFolder::tryToFold(
71     Operation *op, function_ref<void(Operation *)> processGeneratedConstants,
72     function_ref<void(Operation *)> preReplaceAction, bool *inPlaceUpdate) {
73   if (inPlaceUpdate)
74     *inPlaceUpdate = false;
75 
76   // If this is a unique'd constant, return failure as we know that it has
77   // already been folded.
78   if (referencedDialects.count(op))
79     return failure();
80 
81   // Try to fold the operation.
82   SmallVector<Value, 8> results;
83   OpBuilder builder(op);
84   if (failed(tryToFold(builder, op, results, processGeneratedConstants)))
85     return failure();
86 
87   // Check to see if the operation was just updated in place.
88   if (results.empty()) {
89     if (inPlaceUpdate)
90       *inPlaceUpdate = true;
91     return success();
92   }
93 
94   // Constant folding succeeded. We will start replacing this op's uses and
95   // erase this op. Invoke the callback provided by the caller to perform any
96   // pre-replacement action.
97   if (preReplaceAction)
98     preReplaceAction(op);
99 
100   // Replace all of the result values and erase the operation.
101   for (unsigned i = 0, e = results.size(); i != e; ++i)
102     op->getResult(i).replaceAllUsesWith(results[i]);
103   op->erase();
104   return success();
105 }
106 
107 /// Notifies that the given constant `op` should be remove from this
108 /// OperationFolder's internal bookkeeping.
109 void OperationFolder::notifyRemoval(Operation *op) {
110   // Check to see if this operation is uniqued within the folder.
111   auto it = referencedDialects.find(op);
112   if (it == referencedDialects.end())
113     return;
114 
115   // Get the constant value for this operation, this is the value that was used
116   // to unique the operation internally.
117   Attribute constValue;
118   matchPattern(op, m_Constant(&constValue));
119   assert(constValue);
120 
121   // Get the constant map that this operation was uniqued in.
122   auto &uniquedConstants =
123       foldScopes[getInsertionRegion(interfaces, op->getBlock())];
124 
125   // Erase all of the references to this operation.
126   auto type = op->getResult(0).getType();
127   for (auto *dialect : it->second)
128     uniquedConstants.erase(std::make_tuple(dialect, constValue, type));
129   referencedDialects.erase(it);
130 }
131 
132 /// Clear out any constants cached inside of the folder.
133 void OperationFolder::clear() {
134   foldScopes.clear();
135   referencedDialects.clear();
136 }
137 
138 /// Get or create a constant using the given builder. On success this returns
139 /// the constant operation, nullptr otherwise.
140 Value OperationFolder::getOrCreateConstant(OpBuilder &builder, Dialect *dialect,
141                                            Attribute value, Type type,
142                                            Location loc) {
143   OpBuilder::InsertionGuard foldGuard(builder);
144 
145   // Use the builder insertion block to find an insertion point for the
146   // constant.
147   auto *insertRegion =
148       getInsertionRegion(interfaces, builder.getInsertionBlock());
149   auto &entry = insertRegion->front();
150   builder.setInsertionPoint(&entry, entry.begin());
151 
152   // Get the constant map for the insertion region of this operation.
153   auto &uniquedConstants = foldScopes[insertRegion];
154   Operation *constOp = tryGetOrCreateConstant(uniquedConstants, dialect,
155                                               builder, value, type, loc);
156   return constOp ? constOp->getResult(0) : Value();
157 }
158 
159 /// Tries to perform folding on the given `op`. If successful, populates
160 /// `results` with the results of the folding.
161 LogicalResult OperationFolder::tryToFold(
162     OpBuilder &builder, Operation *op, SmallVectorImpl<Value> &results,
163     function_ref<void(Operation *)> processGeneratedConstants) {
164   SmallVector<Attribute, 8> operandConstants;
165   SmallVector<OpFoldResult, 8> foldResults;
166 
167   // If this is a commutative operation, move constants to be trailing operands.
168   if (op->getNumOperands() >= 2 && op->hasTrait<OpTrait::IsCommutative>()) {
169     std::stable_partition(
170         op->getOpOperands().begin(), op->getOpOperands().end(),
171         [&](OpOperand &O) { return !matchPattern(O.get(), m_Constant()); });
172   }
173 
174   // Check to see if any operands to the operation is constant and whether
175   // the operation knows how to constant fold itself.
176   operandConstants.assign(op->getNumOperands(), Attribute());
177   for (unsigned i = 0, e = op->getNumOperands(); i != e; ++i)
178     matchPattern(op->getOperand(i), m_Constant(&operandConstants[i]));
179 
180   // Attempt to constant fold the operation.
181   if (failed(op->fold(operandConstants, foldResults)))
182     return failure();
183 
184   // Check to see if the operation was just updated in place.
185   if (foldResults.empty())
186     return success();
187   assert(foldResults.size() == op->getNumResults());
188 
189   // Create a builder to insert new operations into the entry block of the
190   // insertion region.
191   auto *insertRegion =
192       getInsertionRegion(interfaces, builder.getInsertionBlock());
193   auto &entry = insertRegion->front();
194   OpBuilder::InsertionGuard foldGuard(builder);
195   builder.setInsertionPoint(&entry, entry.begin());
196 
197   // Get the constant map for the insertion region of this operation.
198   auto &uniquedConstants = foldScopes[insertRegion];
199 
200   // Create the result constants and replace the results.
201   auto *dialect = op->getDialect();
202   for (unsigned i = 0, e = op->getNumResults(); i != e; ++i) {
203     assert(!foldResults[i].isNull() && "expected valid OpFoldResult");
204 
205     // Check if the result was an SSA value.
206     if (auto repl = foldResults[i].dyn_cast<Value>()) {
207       if (repl.getType() != op->getResult(i).getType())
208         return failure();
209       results.emplace_back(repl);
210       continue;
211     }
212 
213     // Check to see if there is a canonicalized version of this constant.
214     auto res = op->getResult(i);
215     Attribute attrRepl = foldResults[i].get<Attribute>();
216     if (auto *constOp =
217             tryGetOrCreateConstant(uniquedConstants, dialect, builder, attrRepl,
218                                    res.getType(), op->getLoc())) {
219       // Ensure that this constant dominates the operation we are replacing it
220       // with. This may not automatically happen if the operation being folded
221       // was inserted before the constant within the insertion block.
222       Block *opBlock = op->getBlock();
223       if (opBlock == constOp->getBlock() && &opBlock->front() != constOp)
224         constOp->moveBefore(&opBlock->front());
225 
226       results.push_back(constOp->getResult(0));
227       continue;
228     }
229     // If materialization fails, cleanup any operations generated for the
230     // previous results and return failure.
231     for (Operation &op : llvm::make_early_inc_range(
232              llvm::make_range(entry.begin(), builder.getInsertionPoint()))) {
233       notifyRemoval(&op);
234       op.erase();
235     }
236     return failure();
237   }
238 
239   // Process any newly generated operations.
240   if (processGeneratedConstants) {
241     for (auto i = entry.begin(), e = builder.getInsertionPoint(); i != e; ++i)
242       processGeneratedConstants(&*i);
243   }
244 
245   return success();
246 }
247 
248 /// Try to get or create a new constant entry. On success this returns the
249 /// constant operation value, nullptr otherwise.
250 Operation *OperationFolder::tryGetOrCreateConstant(
251     ConstantMap &uniquedConstants, Dialect *dialect, OpBuilder &builder,
252     Attribute value, Type type, Location loc) {
253   // Check if an existing mapping already exists.
254   auto constKey = std::make_tuple(dialect, value, type);
255   Operation *&constOp = uniquedConstants[constKey];
256   if (constOp)
257     return constOp;
258 
259   // If one doesn't exist, try to materialize one.
260   if (!(constOp = materializeConstant(dialect, builder, value, type, loc)))
261     return nullptr;
262 
263   // Check to see if the generated constant is in the expected dialect.
264   auto *newDialect = constOp->getDialect();
265   if (newDialect == dialect) {
266     referencedDialects[constOp].push_back(dialect);
267     return constOp;
268   }
269 
270   // If it isn't, then we also need to make sure that the mapping for the new
271   // dialect is valid.
272   auto newKey = std::make_tuple(newDialect, value, type);
273 
274   // If an existing operation in the new dialect already exists, delete the
275   // materialized operation in favor of the existing one.
276   if (auto *existingOp = uniquedConstants.lookup(newKey)) {
277     constOp->erase();
278     referencedDialects[existingOp].push_back(dialect);
279     return constOp = existingOp;
280   }
281 
282   // Otherwise, update the new dialect to the materialized operation.
283   referencedDialects[constOp].assign({dialect, newDialect});
284   auto newIt = uniquedConstants.insert({newKey, constOp});
285   return newIt.first->second;
286 }
287