xref: /llvm-project/mlir/lib/Transforms/Utils/FoldUtils.cpp (revision f1b848e4701a4cd3fa781c259e3728faff1c31df)
1 //===- FoldUtils.cpp ---- Fold Utilities ----------------------------------===//
2 //
3 // Copyright 2019 The MLIR Authors.
4 //
5 // Licensed under the Apache License, Version 2.0 (the "License");
6 // you may not use this file except in compliance with the License.
7 // You may obtain a copy of the License at
8 //
9 //   http://www.apache.org/licenses/LICENSE-2.0
10 //
11 // Unless required by applicable law or agreed to in writing, software
12 // distributed under the License is distributed on an "AS IS" BASIS,
13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 // See the License for the specific language governing permissions and
15 // limitations under the License.
16 // =============================================================================
17 //
// This file defines various operation fold utilities. These utilities are
// intended to be used by passes to unify and simplify their logic.
20 //
21 //===----------------------------------------------------------------------===//
22 
23 #include "mlir/Transforms/FoldUtils.h"
24 
25 #include "mlir/IR/Builders.h"
26 #include "mlir/IR/Matchers.h"
27 #include "mlir/IR/Operation.h"
28 #include "mlir/StandardOps/Ops.h"
29 
30 using namespace mlir;
31 
32 //===----------------------------------------------------------------------===//
33 // OperationFolder
34 //===----------------------------------------------------------------------===//
35 
36 LogicalResult
37 OperationFolder::tryToFold(Operation *op,
38                            std::function<void(Operation *)> preReplaceAction) {
39   assert(op->getFunction() == function &&
40          "cannot constant fold op from another function");
41 
42   // The constant op also implements the constant fold hook; it can be folded
43   // into the value it contains. We need to consider constants before the
44   // constant folding logic to avoid re-creating the same constant later.
45   // TODO: Extend to support dialect-specific constant ops.
46   if (auto constant = dyn_cast<ConstantOp>(op)) {
47     // If this constant is dead, update bookkeeping and signal the caller.
48     if (constant.use_empty()) {
49       notifyRemoval(op);
50       op->erase();
51       return success();
52     }
53     // Otherwise, try to see if we can de-duplicate it.
54     return tryToUnify(op);
55   }
56 
57   // Try to fold the operation.
58   SmallVector<Value *, 8> results;
59   if (failed(tryToFold(op, results)))
60     return failure();
61 
62   // Constant folding succeeded. We will start replacing this op's uses and
63   // eventually erase this op. Invoke the callback provided by the caller to
64   // perform any pre-replacement action.
65   if (preReplaceAction)
66     preReplaceAction(op);
67 
68   // Check to see if the operation was just updated in place.
69   if (results.empty())
70     return success();
71 
72   // Otherwise, replace all of the result values and erase the operation.
73   for (unsigned i = 0, e = results.size(); i != e; ++i)
74     op->getResult(i)->replaceAllUsesWith(results[i]);
75   op->erase();
76   return success();
77 }
78 
79 /// Tries to perform folding on the given `op`. If successful, populates
80 /// `results` with the results of the foldin.
81 LogicalResult OperationFolder::tryToFold(Operation *op,
82                                          SmallVectorImpl<Value *> &results) {
83   assert(op->getFunction() == function &&
84          "cannot constant fold op from another function");
85 
86   SmallVector<Attribute, 8> operandConstants;
87   SmallVector<OpFoldResult, 8> foldResults;
88 
89   // Check to see if any operands to the operation is constant and whether
90   // the operation knows how to constant fold itself.
91   operandConstants.assign(op->getNumOperands(), Attribute());
92   for (unsigned i = 0, e = op->getNumOperands(); i != e; ++i)
93     matchPattern(op->getOperand(i), m_Constant(&operandConstants[i]));
94 
95   // If this is a commutative binary operation with a constant on the left
96   // side move it to the right side.
97   if (operandConstants.size() == 2 && operandConstants[0] &&
98       !operandConstants[1] && op->isCommutative()) {
99     std::swap(op->getOpOperand(0), op->getOpOperand(1));
100     std::swap(operandConstants[0], operandConstants[1]);
101   }
102 
103   // Attempt to constant fold the operation.
104   if (failed(op->fold(operandConstants, foldResults)))
105     return failure();
106 
107   // Check to see if the operation was just updated in place.
108   if (foldResults.empty())
109     return success();
110   assert(foldResults.size() == op->getNumResults());
111 
112   // Create the result constants and replace the results.
113   OpBuilder builder(op);
114   for (unsigned i = 0, e = op->getNumResults(); i != e; ++i) {
115     assert(!foldResults[i].isNull() && "expected valid OpFoldResult");
116 
117     // Check if the result was an SSA value.
118     if (auto *repl = foldResults[i].dyn_cast<Value *>()) {
119       results.emplace_back(repl);
120       continue;
121     }
122 
123     // If we already have a canonicalized version of this constant, just reuse
124     // it. Otherwise create a new one.
125     Attribute attrRepl = foldResults[i].get<Attribute>();
126     auto *res = op->getResult(i);
127     auto &constInst =
128         uniquedConstants[std::make_pair(attrRepl, res->getType())];
129     if (!constInst) {
130       // TODO: Extend to support dialect-specific constant ops.
131       auto newOp =
132           builder.create<ConstantOp>(op->getLoc(), res->getType(), attrRepl);
133       // Register to the constant map and also move up to entry block to
134       // guarantee dominance.
135       constInst = newOp.getOperation();
136       moveConstantToEntryBlock(constInst);
137     }
138     results.push_back(constInst->getResult(0));
139   }
140 
141   return success();
142 }
143 
144 void OperationFolder::notifyRemoval(Operation *op) {
145   assert(op->getFunction() == function &&
146          "cannot remove constant from another function");
147 
148   Attribute constValue;
149   if (!matchPattern(op, m_Constant(&constValue)))
150     return;
151 
152   // This constant is dead. keep uniquedConstants up to date.
153   auto it = uniquedConstants.find({constValue, op->getResult(0)->getType()});
154   if (it != uniquedConstants.end() && it->second == op)
155     uniquedConstants.erase(it);
156 }
157 
158 LogicalResult OperationFolder::tryToUnify(Operation *op) {
159   Attribute constValue;
160   matchPattern(op, m_Constant(&constValue));
161   assert(constValue);
162 
163   // Check to see if we already have a constant with this type and value:
164   auto &constInst =
165       uniquedConstants[std::make_pair(constValue, op->getResult(0)->getType())];
166   if (constInst) {
167     // If this constant is already our uniqued one, then leave it alone.
168     if (constInst == op)
169       return failure();
170 
171     // Otherwise replace this redundant constant with the uniqued one.  We know
172     // this is safe because we move constants to the top of the function when
173     // they are uniqued, so we know they dominate all uses.
174     op->getResult(0)->replaceAllUsesWith(constInst->getResult(0));
175     op->erase();
176     return success();
177   }
178 
179   // If we have no entry, then we should unique this constant as the
180   // canonical version.  To ensure safe dominance, move the operation to the
181   // entry block of the function.
182   constInst = op;
183   moveConstantToEntryBlock(op);
184   return failure();
185 }
186 
187 void OperationFolder::moveConstantToEntryBlock(Operation *op) {
188   // Insert at the very top of the entry block.
189   auto &entryBB = function->front();
190   op->moveBefore(&entryBB, entryBB.begin());
191 }
192