1 //===- FoldUtils.cpp ---- Fold Utilities ----------------------------------===// 2 // 3 // Copyright 2019 The MLIR Authors. 4 // 5 // Licensed under the Apache License, Version 2.0 (the "License"); 6 // you may not use this file except in compliance with the License. 7 // You may obtain a copy of the License at 8 // 9 // http://www.apache.org/licenses/LICENSE-2.0 10 // 11 // Unless required by applicable law or agreed to in writing, software 12 // distributed under the License is distributed on an "AS IS" BASIS, 13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 // See the License for the specific language governing permissions and 15 // limitations under the License. 16 // ============================================================================= 17 // 18 // This file defines various operation fold utilities. These utilities are 19 // intended to be used by passes to unify and simply their logic. 20 // 21 //===----------------------------------------------------------------------===// 22 23 #include "mlir/Transforms/FoldUtils.h" 24 25 #include "mlir/IR/Builders.h" 26 #include "mlir/IR/Matchers.h" 27 #include "mlir/IR/Operation.h" 28 #include "mlir/StandardOps/Ops.h" 29 30 using namespace mlir; 31 32 //===----------------------------------------------------------------------===// 33 // OperationFolder 34 //===----------------------------------------------------------------------===// 35 36 LogicalResult 37 OperationFolder::tryToFold(Operation *op, 38 std::function<void(Operation *)> preReplaceAction) { 39 assert(op->getFunction() == function && 40 "cannot constant fold op from another function"); 41 42 // The constant op also implements the constant fold hook; it can be folded 43 // into the value it contains. We need to consider constants before the 44 // constant folding logic to avoid re-creating the same constant later. 45 // TODO: Extend to support dialect-specific constant ops. 46 if (auto constant = dyn_cast<ConstantOp>(op)) { 47 // If this constant is dead, update bookkeeping and signal the caller. 48 if (constant.use_empty()) { 49 notifyRemoval(op); 50 op->erase(); 51 return success(); 52 } 53 // Otherwise, try to see if we can de-duplicate it. 54 return tryToUnify(op); 55 } 56 57 // Try to fold the operation. 58 SmallVector<Value *, 8> results; 59 if (failed(tryToFold(op, results))) 60 return failure(); 61 62 // Constant folding succeeded. We will start replacing this op's uses and 63 // eventually erase this op. Invoke the callback provided by the caller to 64 // perform any pre-replacement action. 65 if (preReplaceAction) 66 preReplaceAction(op); 67 68 // Check to see if the operation was just updated in place. 69 if (results.empty()) 70 return success(); 71 72 // Otherwise, replace all of the result values and erase the operation. 73 for (unsigned i = 0, e = results.size(); i != e; ++i) 74 op->getResult(i)->replaceAllUsesWith(results[i]); 75 op->erase(); 76 return success(); 77 } 78 79 /// Tries to perform folding on the given `op`. If successful, populates 80 /// `results` with the results of the foldin. 81 LogicalResult OperationFolder::tryToFold(Operation *op, 82 SmallVectorImpl<Value *> &results) { 83 assert(op->getFunction() == function && 84 "cannot constant fold op from another function"); 85 86 SmallVector<Attribute, 8> operandConstants; 87 SmallVector<OpFoldResult, 8> foldResults; 88 89 // Check to see if any operands to the operation is constant and whether 90 // the operation knows how to constant fold itself. 91 operandConstants.assign(op->getNumOperands(), Attribute()); 92 for (unsigned i = 0, e = op->getNumOperands(); i != e; ++i) 93 matchPattern(op->getOperand(i), m_Constant(&operandConstants[i])); 94 95 // If this is a commutative binary operation with a constant on the left 96 // side move it to the right side. 97 if (operandConstants.size() == 2 && operandConstants[0] && 98 !operandConstants[1] && op->isCommutative()) { 99 std::swap(op->getOpOperand(0), op->getOpOperand(1)); 100 std::swap(operandConstants[0], operandConstants[1]); 101 } 102 103 // Attempt to constant fold the operation. 104 if (failed(op->fold(operandConstants, foldResults))) 105 return failure(); 106 107 // Check to see if the operation was just updated in place. 108 if (foldResults.empty()) 109 return success(); 110 assert(foldResults.size() == op->getNumResults()); 111 112 // Create the result constants and replace the results. 113 OpBuilder builder(op); 114 for (unsigned i = 0, e = op->getNumResults(); i != e; ++i) { 115 assert(!foldResults[i].isNull() && "expected valid OpFoldResult"); 116 117 // Check if the result was an SSA value. 118 if (auto *repl = foldResults[i].dyn_cast<Value *>()) { 119 results.emplace_back(repl); 120 continue; 121 } 122 123 // If we already have a canonicalized version of this constant, just reuse 124 // it. Otherwise create a new one. 125 Attribute attrRepl = foldResults[i].get<Attribute>(); 126 auto *res = op->getResult(i); 127 auto &constInst = 128 uniquedConstants[std::make_pair(attrRepl, res->getType())]; 129 if (!constInst) { 130 // TODO: Extend to support dialect-specific constant ops. 131 auto newOp = 132 builder.create<ConstantOp>(op->getLoc(), res->getType(), attrRepl); 133 // Register to the constant map and also move up to entry block to 134 // guarantee dominance. 135 constInst = newOp.getOperation(); 136 moveConstantToEntryBlock(constInst); 137 } 138 results.push_back(constInst->getResult(0)); 139 } 140 141 return success(); 142 } 143 144 void OperationFolder::notifyRemoval(Operation *op) { 145 assert(op->getFunction() == function && 146 "cannot remove constant from another function"); 147 148 Attribute constValue; 149 if (!matchPattern(op, m_Constant(&constValue))) 150 return; 151 152 // This constant is dead. keep uniquedConstants up to date. 153 auto it = uniquedConstants.find({constValue, op->getResult(0)->getType()}); 154 if (it != uniquedConstants.end() && it->second == op) 155 uniquedConstants.erase(it); 156 } 157 158 LogicalResult OperationFolder::tryToUnify(Operation *op) { 159 Attribute constValue; 160 matchPattern(op, m_Constant(&constValue)); 161 assert(constValue); 162 163 // Check to see if we already have a constant with this type and value: 164 auto &constInst = 165 uniquedConstants[std::make_pair(constValue, op->getResult(0)->getType())]; 166 if (constInst) { 167 // If this constant is already our uniqued one, then leave it alone. 168 if (constInst == op) 169 return failure(); 170 171 // Otherwise replace this redundant constant with the uniqued one. We know 172 // this is safe because we move constants to the top of the function when 173 // they are uniqued, so we know they dominate all uses. 174 op->getResult(0)->replaceAllUsesWith(constInst->getResult(0)); 175 op->erase(); 176 return success(); 177 } 178 179 // If we have no entry, then we should unique this constant as the 180 // canonical version. To ensure safe dominance, move the operation to the 181 // entry block of the function. 182 constInst = op; 183 moveConstantToEntryBlock(op); 184 return failure(); 185 } 186 187 void OperationFolder::moveConstantToEntryBlock(Operation *op) { 188 // Insert at the very top of the entry block. 189 auto &entryBB = function->front(); 190 op->moveBefore(&entryBB, entryBB.begin()); 191 } 192