1 //===- FoldUtils.cpp ---- Fold Utilities ----------------------------------===// 2 // 3 // Copyright 2019 The MLIR Authors. 4 // 5 // Licensed under the Apache License, Version 2.0 (the "License"); 6 // you may not use this file except in compliance with the License. 7 // You may obtain a copy of the License at 8 // 9 // http://www.apache.org/licenses/LICENSE-2.0 10 // 11 // Unless required by applicable law or agreed to in writing, software 12 // distributed under the License is distributed on an "AS IS" BASIS, 13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 // See the License for the specific language governing permissions and 15 // limitations under the License. 16 // ============================================================================= 17 // 18 // This file defines various operation fold utilities. These utilities are 19 // intended to be used by passes to unify and simply their logic. 20 // 21 //===----------------------------------------------------------------------===// 22 23 #include "mlir/Transforms/FoldUtils.h" 24 25 #include "mlir/IR/Builders.h" 26 #include "mlir/IR/Matchers.h" 27 #include "mlir/IR/Operation.h" 28 #include "mlir/StandardOps/Ops.h" 29 30 using namespace mlir; 31 32 FoldHelper::FoldHelper(Function *f) : function(f) {} 33 34 LogicalResult 35 FoldHelper::tryToFold(Operation *op, 36 std::function<void(Operation *)> preReplaceAction) { 37 assert(op->getFunction() == function && 38 "cannot constant fold op from another function"); 39 40 // The constant op also implements the constant fold hook; it can be folded 41 // into the value it contains. We need to consider constants before the 42 // constant folding logic to avoid re-creating the same constant later. 43 // TODO: Extend to support dialect-specific constant ops. 44 if (auto constant = dyn_cast<ConstantOp>(op)) { 45 // If this constant is dead, update bookkeeping and signal the caller. 46 if (constant.use_empty()) { 47 notifyRemoval(op); 48 op->erase(); 49 return success(); 50 } 51 // Otherwise, try to see if we can de-duplicate it. 52 return tryToUnify(op); 53 } 54 55 SmallVector<Attribute, 8> operandConstants; 56 SmallVector<OpFoldResult, 8> results; 57 58 // Check to see if any operands to the operation is constant and whether 59 // the operation knows how to constant fold itself. 60 operandConstants.assign(op->getNumOperands(), Attribute()); 61 for (unsigned i = 0, e = op->getNumOperands(); i != e; ++i) 62 matchPattern(op->getOperand(i), m_Constant(&operandConstants[i])); 63 64 // If this is a commutative binary operation with a constant on the left 65 // side move it to the right side. 66 if (operandConstants.size() == 2 && operandConstants[0] && 67 !operandConstants[1] && op->isCommutative()) { 68 std::swap(op->getOpOperand(0), op->getOpOperand(1)); 69 std::swap(operandConstants[0], operandConstants[1]); 70 } 71 72 // Attempt to constant fold the operation. 73 if (failed(op->fold(operandConstants, results))) 74 return failure(); 75 76 // Constant folding succeeded. We will start replacing this op's uses and 77 // eventually erase this op. Invoke the callback provided by the caller to 78 // perform any pre-replacement action. 79 if (preReplaceAction) 80 preReplaceAction(op); 81 82 // Check to see if the operation was just updated in place. 83 if (results.empty()) 84 return success(); 85 assert(results.size() == op->getNumResults()); 86 87 // Create the result constants and replace the results. 88 FuncBuilder builder(op); 89 for (unsigned i = 0, e = op->getNumResults(); i != e; ++i) { 90 auto *res = op->getResult(i); 91 if (res->use_empty()) // Ignore dead uses. 92 continue; 93 assert(!results[i].isNull() && "expected valid OpFoldResult"); 94 95 // Check if the result was an SSA value. 96 if (auto *repl = results[i].dyn_cast<Value *>()) { 97 if (repl != res) 98 res->replaceAllUsesWith(repl); 99 continue; 100 } 101 102 // If we already have a canonicalized version of this constant, just reuse 103 // it. Otherwise create a new one. 104 Attribute attrRepl = results[i].get<Attribute>(); 105 auto &constInst = 106 uniquedConstants[std::make_pair(attrRepl, res->getType())]; 107 if (!constInst) { 108 // TODO: Extend to support dialect-specific constant ops. 109 auto newOp = 110 builder.create<ConstantOp>(op->getLoc(), res->getType(), attrRepl); 111 // Register to the constant map and also move up to entry block to 112 // guarantee dominance. 113 constInst = newOp.getOperation(); 114 moveConstantToEntryBlock(constInst); 115 } 116 res->replaceAllUsesWith(constInst->getResult(0)); 117 } 118 op->erase(); 119 120 return success(); 121 } 122 123 void FoldHelper::notifyRemoval(Operation *op) { 124 assert(op->getFunction() == function && 125 "cannot remove constant from another function"); 126 127 Attribute constValue; 128 if (!matchPattern(op, m_Constant(&constValue))) 129 return; 130 131 // This constant is dead. keep uniquedConstants up to date. 132 auto it = uniquedConstants.find({constValue, op->getResult(0)->getType()}); 133 if (it != uniquedConstants.end() && it->second == op) 134 uniquedConstants.erase(it); 135 } 136 137 LogicalResult FoldHelper::tryToUnify(Operation *op) { 138 Attribute constValue; 139 matchPattern(op, m_Constant(&constValue)); 140 assert(constValue); 141 142 // Check to see if we already have a constant with this type and value: 143 auto &constInst = 144 uniquedConstants[std::make_pair(constValue, op->getResult(0)->getType())]; 145 if (constInst) { 146 // If this constant is already our uniqued one, then leave it alone. 147 if (constInst == op) 148 return failure(); 149 150 // Otherwise replace this redundant constant with the uniqued one. We know 151 // this is safe because we move constants to the top of the function when 152 // they are uniqued, so we know they dominate all uses. 153 op->getResult(0)->replaceAllUsesWith(constInst->getResult(0)); 154 op->erase(); 155 return success(); 156 } 157 158 // If we have no entry, then we should unique this constant as the 159 // canonical version. To ensure safe dominance, move the operation to the 160 // entry block of the function. 161 constInst = op; 162 moveConstantToEntryBlock(op); 163 return failure(); 164 } 165 166 void FoldHelper::moveConstantToEntryBlock(Operation *op) { 167 // Insert at the very top of the entry block. 168 auto &entryBB = function->front(); 169 op->moveBefore(&entryBB, entryBB.begin()); 170 } 171