xref: /llvm-project/mlir/lib/Transforms/Utils/FoldUtils.cpp (revision 1982afb145bcd2a4252b6bd32edd42c52effd9aa)
1 //===- FoldUtils.cpp ---- Fold Utilities ----------------------------------===//
2 //
3 // Copyright 2019 The MLIR Authors.
4 //
5 // Licensed under the Apache License, Version 2.0 (the "License");
6 // you may not use this file except in compliance with the License.
7 // You may obtain a copy of the License at
8 //
9 //   http://www.apache.org/licenses/LICENSE-2.0
10 //
11 // Unless required by applicable law or agreed to in writing, software
12 // distributed under the License is distributed on an "AS IS" BASIS,
13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 // See the License for the specific language governing permissions and
15 // limitations under the License.
16 // =============================================================================
17 //
18 // This file defines various operation fold utilities. These utilities are
19 // intended to be used by passes to unify and simplify their logic.
20 //
21 //===----------------------------------------------------------------------===//
22 
23 #include "mlir/Transforms/FoldUtils.h"
24 
25 #include "mlir/IR/Builders.h"
26 #include "mlir/IR/Matchers.h"
27 #include "mlir/IR/Operation.h"
28 #include "mlir/StandardOps/Ops.h"
29 
30 using namespace mlir;
31 
32 FoldHelper::FoldHelper(Function *f) : function(f) {}
33 
34 LogicalResult
35 FoldHelper::tryToFold(Operation *op,
36                       std::function<void(Operation *)> preReplaceAction) {
37   assert(op->getFunction() == function &&
38          "cannot constant fold op from another function");
39 
40   // The constant op also implements the constant fold hook; it can be folded
41   // into the value it contains. We need to consider constants before the
42   // constant folding logic to avoid re-creating the same constant later.
43   // TODO: Extend to support dialect-specific constant ops.
44   if (auto constant = dyn_cast<ConstantOp>(op)) {
45     // If this constant is dead, update bookkeeping and signal the caller.
46     if (constant.use_empty()) {
47       notifyRemoval(op);
48       op->erase();
49       return success();
50     }
51     // Otherwise, try to see if we can de-duplicate it.
52     return tryToUnify(op);
53   }
54 
55   SmallVector<Attribute, 8> operandConstants;
56   SmallVector<OpFoldResult, 8> results;
57 
58   // Check to see if any operands to the operation is constant and whether
59   // the operation knows how to constant fold itself.
60   operandConstants.assign(op->getNumOperands(), Attribute());
61   for (unsigned i = 0, e = op->getNumOperands(); i != e; ++i)
62     matchPattern(op->getOperand(i), m_Constant(&operandConstants[i]));
63 
64   // If this is a commutative binary operation with a constant on the left
65   // side move it to the right side.
66   if (operandConstants.size() == 2 && operandConstants[0] &&
67       !operandConstants[1] && op->isCommutative()) {
68     std::swap(op->getOpOperand(0), op->getOpOperand(1));
69     std::swap(operandConstants[0], operandConstants[1]);
70   }
71 
72   // Attempt to constant fold the operation.
73   if (failed(op->fold(operandConstants, results)))
74     return failure();
75 
76   // Constant folding succeeded. We will start replacing this op's uses and
77   // eventually erase this op. Invoke the callback provided by the caller to
78   // perform any pre-replacement action.
79   if (preReplaceAction)
80     preReplaceAction(op);
81 
82   // Check to see if the operation was just updated in place.
83   if (results.empty())
84     return success();
85   assert(results.size() == op->getNumResults());
86 
87   // Create the result constants and replace the results.
88   FuncBuilder builder(op);
89   for (unsigned i = 0, e = op->getNumResults(); i != e; ++i) {
90     auto *res = op->getResult(i);
91     if (res->use_empty()) // Ignore dead uses.
92       continue;
93     assert(!results[i].isNull() && "expected valid OpFoldResult");
94 
95     // Check if the result was an SSA value.
96     if (auto *repl = results[i].dyn_cast<Value *>()) {
97       if (repl != res)
98         res->replaceAllUsesWith(repl);
99       continue;
100     }
101 
102     // If we already have a canonicalized version of this constant, just reuse
103     // it.  Otherwise create a new one.
104     Attribute attrRepl = results[i].get<Attribute>();
105     auto &constInst =
106         uniquedConstants[std::make_pair(attrRepl, res->getType())];
107     if (!constInst) {
108       // TODO: Extend to support dialect-specific constant ops.
109       auto newOp =
110           builder.create<ConstantOp>(op->getLoc(), res->getType(), attrRepl);
111       // Register to the constant map and also move up to entry block to
112       // guarantee dominance.
113       constInst = newOp.getOperation();
114       moveConstantToEntryBlock(constInst);
115     }
116     res->replaceAllUsesWith(constInst->getResult(0));
117   }
118   op->erase();
119 
120   return success();
121 }
122 
123 void FoldHelper::notifyRemoval(Operation *op) {
124   assert(op->getFunction() == function &&
125          "cannot remove constant from another function");
126 
127   Attribute constValue;
128   if (!matchPattern(op, m_Constant(&constValue)))
129     return;
130 
131   // This constant is dead. keep uniquedConstants up to date.
132   auto it = uniquedConstants.find({constValue, op->getResult(0)->getType()});
133   if (it != uniquedConstants.end() && it->second == op)
134     uniquedConstants.erase(it);
135 }
136 
137 LogicalResult FoldHelper::tryToUnify(Operation *op) {
138   Attribute constValue;
139   matchPattern(op, m_Constant(&constValue));
140   assert(constValue);
141 
142   // Check to see if we already have a constant with this type and value:
143   auto &constInst =
144       uniquedConstants[std::make_pair(constValue, op->getResult(0)->getType())];
145   if (constInst) {
146     // If this constant is already our uniqued one, then leave it alone.
147     if (constInst == op)
148       return failure();
149 
150     // Otherwise replace this redundant constant with the uniqued one.  We know
151     // this is safe because we move constants to the top of the function when
152     // they are uniqued, so we know they dominate all uses.
153     op->getResult(0)->replaceAllUsesWith(constInst->getResult(0));
154     op->erase();
155     return success();
156   }
157 
158   // If we have no entry, then we should unique this constant as the
159   // canonical version.  To ensure safe dominance, move the operation to the
160   // entry block of the function.
161   constInst = op;
162   moveConstantToEntryBlock(op);
163   return failure();
164 }
165 
166 void FoldHelper::moveConstantToEntryBlock(Operation *op) {
167   // Insert at the very top of the entry block.
168   auto &entryBB = function->front();
169   op->moveBefore(&entryBB, entryBB.begin());
170 }
171