xref: /llvm-project/mlir/lib/Transforms/Utils/FoldUtils.cpp (revision bcacef1a70d1405c4e1ee0d7aeaa71b9abb8cdff)
1 //===- FoldUtils.cpp ---- Fold Utilities ----------------------------------===//
2 //
3 // Copyright 2019 The MLIR Authors.
4 //
5 // Licensed under the Apache License, Version 2.0 (the "License");
6 // you may not use this file except in compliance with the License.
7 // You may obtain a copy of the License at
8 //
9 //   http://www.apache.org/licenses/LICENSE-2.0
10 //
11 // Unless required by applicable law or agreed to in writing, software
12 // distributed under the License is distributed on an "AS IS" BASIS,
13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 // See the License for the specific language governing permissions and
15 // limitations under the License.
16 // =============================================================================
17 //
// This file defines various operation fold utilities. These utilities are
// intended to be used by passes to unify and simplify their logic.
20 //
21 //===----------------------------------------------------------------------===//
22 
23 #include "mlir/Transforms/FoldUtils.h"
24 
25 #include "mlir/IR/Builders.h"
26 #include "mlir/IR/Matchers.h"
27 #include "mlir/IR/Operation.h"
28 #include "mlir/StandardOps/Ops.h"
29 
30 using namespace mlir;
31 
32 //===----------------------------------------------------------------------===//
33 // OperationFolder
34 //===----------------------------------------------------------------------===//
35 
36 LogicalResult OperationFolder::tryToFold(
37     Operation *op,
38     llvm::function_ref<void(Operation *)> processGeneratedConstants,
39     llvm::function_ref<void(Operation *)> preReplaceAction) {
40   assert(op->getFunction() == function &&
41          "cannot constant fold op from another function");
42 
43   // If this is a unique'd constant, return failure as we know that it has
44   // already been folded.
45   if (referencedDialects.count(op))
46     return failure();
47 
48   // Try to fold the operation.
49   SmallVector<Value *, 8> results;
50   if (failed(tryToFold(op, results, processGeneratedConstants)))
51     return failure();
52 
53   // Constant folding succeeded. We will start replacing this op's uses and
54   // eventually erase this op. Invoke the callback provided by the caller to
55   // perform any pre-replacement action.
56   if (preReplaceAction)
57     preReplaceAction(op);
58 
59   // Check to see if the operation was just updated in place.
60   if (results.empty())
61     return success();
62 
63   // Otherwise, replace all of the result values and erase the operation.
64   for (unsigned i = 0, e = results.size(); i != e; ++i)
65     op->getResult(i)->replaceAllUsesWith(results[i]);
66   op->erase();
67   return success();
68 }
69 
70 /// Notifies that the given constant `op` should be remove from this
71 /// OperationFolder's internal bookkeeping.
72 void OperationFolder::notifyRemoval(Operation *op) {
73   assert(op->getFunction() == function &&
74          "cannot remove constant from another function");
75 
76   // Check to see if this operation is uniqued within the folder.
77   auto it = referencedDialects.find(op);
78   if (it == referencedDialects.end())
79     return;
80 
81   // Get the constant value for this operation, this is the value that was used
82   // to unique the operation internally.
83   Attribute constValue;
84   matchPattern(op, m_Constant(&constValue));
85   assert(constValue);
86 
87   // Erase all of the references to this operation.
88   auto type = op->getResult(0)->getType();
89   for (auto *dialect : it->second)
90     uniquedConstants.erase(std::make_tuple(dialect, constValue, type));
91   referencedDialects.erase(it);
92 }
93 
94 /// A utility function used to materialize a constant for a given attribute and
95 /// type. On success, a valid constant value is returned. Otherwise, null is
96 /// returned
97 static Operation *materializeConstant(Dialect *dialect, OpBuilder &builder,
98                                       Attribute value, Type type,
99                                       Location loc) {
100   auto insertPt = builder.getInsertionPoint();
101   (void)insertPt;
102 
103   // Ask the dialect to materialize a constant operation for this value.
104   if (auto *constOp = dialect->materializeConstant(builder, value, type, loc)) {
105     assert(insertPt == builder.getInsertionPoint());
106     assert(matchPattern(constOp, m_Constant(&value)));
107     return constOp;
108   }
109 
110   // If the dialect is unable to materialize a constant, check to see if the
111   // standard constant can be used.
112   if (ConstantOp::isBuildableWith(value, type))
113     return builder.create<ConstantOp>(loc, type, value);
114   return nullptr;
115 }
116 
117 /// Tries to perform folding on the given `op`. If successful, populates
118 /// `results` with the results of the folding.
119 LogicalResult OperationFolder::tryToFold(
120     Operation *op, SmallVectorImpl<Value *> &results,
121     llvm::function_ref<void(Operation *)> processGeneratedConstants) {
122   assert(op->getFunction() == function &&
123          "cannot constant fold op from another function");
124 
125   SmallVector<Attribute, 8> operandConstants;
126   SmallVector<OpFoldResult, 8> foldResults;
127 
128   // Check to see if any operands to the operation is constant and whether
129   // the operation knows how to constant fold itself.
130   operandConstants.assign(op->getNumOperands(), Attribute());
131   for (unsigned i = 0, e = op->getNumOperands(); i != e; ++i)
132     matchPattern(op->getOperand(i), m_Constant(&operandConstants[i]));
133 
134   // If this is a commutative binary operation with a constant on the left
135   // side move it to the right side.
136   if (operandConstants.size() == 2 && operandConstants[0] &&
137       !operandConstants[1] && op->isCommutative()) {
138     std::swap(op->getOpOperand(0), op->getOpOperand(1));
139     std::swap(operandConstants[0], operandConstants[1]);
140   }
141 
142   // Attempt to constant fold the operation.
143   if (failed(op->fold(operandConstants, foldResults)))
144     return failure();
145 
146   // Check to see if the operation was just updated in place.
147   if (foldResults.empty())
148     return success();
149   assert(foldResults.size() == op->getNumResults());
150 
151   // Create a builder to insert new operations into the entry block.
152   auto &entry = function->getBody().front();
153   OpBuilder builder(&entry, entry.empty() ? entry.end() : entry.begin());
154 
155   // Create the result constants and replace the results.
156   auto *dialect = op->getDialect();
157   for (unsigned i = 0, e = op->getNumResults(); i != e; ++i) {
158     assert(!foldResults[i].isNull() && "expected valid OpFoldResult");
159 
160     // Check if the result was an SSA value.
161     if (auto *repl = foldResults[i].dyn_cast<Value *>()) {
162       results.emplace_back(repl);
163       continue;
164     }
165 
166     // Check to see if there is a canonicalized version of this constant.
167     auto *res = op->getResult(i);
168     Attribute attrRepl = foldResults[i].get<Attribute>();
169     if (auto *constOp = tryGetOrCreateConstant(dialect, builder, attrRepl,
170                                                res->getType(), op->getLoc())) {
171       results.push_back(constOp->getResult(0));
172       continue;
173     }
174     // If materialization fails, cleanup any operations generated for the
175     // previous results and return failure.
176     for (Operation &op : llvm::make_early_inc_range(
177              llvm::make_range(entry.begin(), builder.getInsertionPoint()))) {
178       notifyRemoval(&op);
179       op.erase();
180     }
181     return failure();
182   }
183 
184   // Process any newly generated operations.
185   if (processGeneratedConstants) {
186     for (auto i = entry.begin(), e = builder.getInsertionPoint(); i != e; ++i)
187       processGeneratedConstants(&*i);
188   }
189 
190   return success();
191 }
192 
193 /// Try to get or create a new constant entry. On success this returns the
194 /// constant operation value, nullptr otherwise.
195 Operation *OperationFolder::tryGetOrCreateConstant(Dialect *dialect,
196                                                    OpBuilder &builder,
197                                                    Attribute value, Type type,
198                                                    Location loc) {
199   // Check if an existing mapping already exists.
200   auto constKey = std::make_tuple(dialect, value, type);
201   auto *&constInst = uniquedConstants[constKey];
202   if (constInst)
203     return constInst;
204 
205   // If one doesn't exist, try to materialize one.
206   if (!(constInst = materializeConstant(dialect, builder, value, type, loc)))
207     return nullptr;
208 
209   // Check to see if the generated constant is in the expected dialect.
210   auto *newDialect = constInst->getDialect();
211   if (newDialect == dialect) {
212     referencedDialects[constInst].push_back(dialect);
213     return constInst;
214   }
215 
216   // If it isn't, then we also need to make sure that the mapping for the new
217   // dialect is valid.
218   auto newKey = std::make_tuple(newDialect, value, type);
219 
220   // If an existing operation in the new dialect already exists, delete the
221   // materialized operation in favor of the existing one.
222   if (auto *existingOp = uniquedConstants.lookup(newKey)) {
223     constInst->erase();
224     referencedDialects[existingOp].push_back(dialect);
225     return constInst = existingOp;
226   }
227 
228   // Otherwise, update the new dialect to the materialized operation.
229   referencedDialects[constInst].assign({dialect, newDialect});
230   auto newIt = uniquedConstants.insert({newKey, constInst});
231   return newIt.first->second;
232 }
233