1 //===- FoldUtils.cpp ---- Fold Utilities ----------------------------------===// 2 // 3 // Copyright 2019 The MLIR Authors. 4 // 5 // Licensed under the Apache License, Version 2.0 (the "License"); 6 // you may not use this file except in compliance with the License. 7 // You may obtain a copy of the License at 8 // 9 // http://www.apache.org/licenses/LICENSE-2.0 10 // 11 // Unless required by applicable law or agreed to in writing, software 12 // distributed under the License is distributed on an "AS IS" BASIS, 13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 // See the License for the specific language governing permissions and 15 // limitations under the License. 16 // ============================================================================= 17 // 18 // This file defines various operation fold utilities. These utilities are 19 // intended to be used by passes to unify and simply their logic. 20 // 21 //===----------------------------------------------------------------------===// 22 23 #include "mlir/Transforms/FoldUtils.h" 24 25 #include "mlir/IR/Builders.h" 26 #include "mlir/IR/Matchers.h" 27 #include "mlir/IR/Operation.h" 28 #include "mlir/StandardOps/Ops.h" 29 30 using namespace mlir; 31 32 //===----------------------------------------------------------------------===// 33 // OperationFolder 34 //===----------------------------------------------------------------------===// 35 36 LogicalResult OperationFolder::tryToFold( 37 Operation *op, 38 llvm::function_ref<void(Operation *)> processGeneratedConstants, 39 llvm::function_ref<void(Operation *)> preReplaceAction) { 40 assert(op->getFunction() == function && 41 "cannot constant fold op from another function"); 42 43 // If this is a unique'd constant, return failure as we know that it has 44 // already been folded. 45 if (referencedDialects.count(op)) 46 return failure(); 47 48 // Try to fold the operation. 49 SmallVector<Value *, 8> results; 50 if (failed(tryToFold(op, results, processGeneratedConstants))) 51 return failure(); 52 53 // Constant folding succeeded. We will start replacing this op's uses and 54 // eventually erase this op. Invoke the callback provided by the caller to 55 // perform any pre-replacement action. 56 if (preReplaceAction) 57 preReplaceAction(op); 58 59 // Check to see if the operation was just updated in place. 60 if (results.empty()) 61 return success(); 62 63 // Otherwise, replace all of the result values and erase the operation. 64 for (unsigned i = 0, e = results.size(); i != e; ++i) 65 op->getResult(i)->replaceAllUsesWith(results[i]); 66 op->erase(); 67 return success(); 68 } 69 70 /// Notifies that the given constant `op` should be remove from this 71 /// OperationFolder's internal bookkeeping. 72 void OperationFolder::notifyRemoval(Operation *op) { 73 assert(op->getFunction() == function && 74 "cannot remove constant from another function"); 75 76 // Check to see if this operation is uniqued within the folder. 77 auto it = referencedDialects.find(op); 78 if (it == referencedDialects.end()) 79 return; 80 81 // Get the constant value for this operation, this is the value that was used 82 // to unique the operation internally. 83 Attribute constValue; 84 matchPattern(op, m_Constant(&constValue)); 85 assert(constValue); 86 87 // Erase all of the references to this operation. 88 auto type = op->getResult(0)->getType(); 89 for (auto *dialect : it->second) 90 uniquedConstants.erase(std::make_tuple(dialect, constValue, type)); 91 referencedDialects.erase(it); 92 } 93 94 /// A utility function used to materialize a constant for a given attribute and 95 /// type. On success, a valid constant value is returned. Otherwise, null is 96 /// returned 97 static Operation *materializeConstant(Dialect *dialect, OpBuilder &builder, 98 Attribute value, Type type, 99 Location loc) { 100 auto insertPt = builder.getInsertionPoint(); 101 (void)insertPt; 102 103 // Ask the dialect to materialize a constant operation for this value. 104 if (auto *constOp = dialect->materializeConstant(builder, value, type, loc)) { 105 assert(insertPt == builder.getInsertionPoint()); 106 assert(matchPattern(constOp, m_Constant(&value))); 107 return constOp; 108 } 109 110 // If the dialect is unable to materialize a constant, check to see if the 111 // standard constant can be used. 112 if (ConstantOp::isBuildableWith(value, type)) 113 return builder.create<ConstantOp>(loc, type, value); 114 return nullptr; 115 } 116 117 /// Tries to perform folding on the given `op`. If successful, populates 118 /// `results` with the results of the folding. 119 LogicalResult OperationFolder::tryToFold( 120 Operation *op, SmallVectorImpl<Value *> &results, 121 llvm::function_ref<void(Operation *)> processGeneratedConstants) { 122 assert(op->getFunction() == function && 123 "cannot constant fold op from another function"); 124 125 SmallVector<Attribute, 8> operandConstants; 126 SmallVector<OpFoldResult, 8> foldResults; 127 128 // Check to see if any operands to the operation is constant and whether 129 // the operation knows how to constant fold itself. 130 operandConstants.assign(op->getNumOperands(), Attribute()); 131 for (unsigned i = 0, e = op->getNumOperands(); i != e; ++i) 132 matchPattern(op->getOperand(i), m_Constant(&operandConstants[i])); 133 134 // If this is a commutative binary operation with a constant on the left 135 // side move it to the right side. 136 if (operandConstants.size() == 2 && operandConstants[0] && 137 !operandConstants[1] && op->isCommutative()) { 138 std::swap(op->getOpOperand(0), op->getOpOperand(1)); 139 std::swap(operandConstants[0], operandConstants[1]); 140 } 141 142 // Attempt to constant fold the operation. 143 if (failed(op->fold(operandConstants, foldResults))) 144 return failure(); 145 146 // Check to see if the operation was just updated in place. 147 if (foldResults.empty()) 148 return success(); 149 assert(foldResults.size() == op->getNumResults()); 150 151 // Create a builder to insert new operations into the entry block. 152 auto &entry = function->getBody().front(); 153 OpBuilder builder(&entry, entry.empty() ? entry.end() : entry.begin()); 154 155 // Create the result constants and replace the results. 156 auto *dialect = op->getDialect(); 157 for (unsigned i = 0, e = op->getNumResults(); i != e; ++i) { 158 assert(!foldResults[i].isNull() && "expected valid OpFoldResult"); 159 160 // Check if the result was an SSA value. 161 if (auto *repl = foldResults[i].dyn_cast<Value *>()) { 162 results.emplace_back(repl); 163 continue; 164 } 165 166 // Check to see if there is a canonicalized version of this constant. 167 auto *res = op->getResult(i); 168 Attribute attrRepl = foldResults[i].get<Attribute>(); 169 if (auto *constOp = tryGetOrCreateConstant(dialect, builder, attrRepl, 170 res->getType(), op->getLoc())) { 171 results.push_back(constOp->getResult(0)); 172 continue; 173 } 174 // If materialization fails, cleanup any operations generated for the 175 // previous results and return failure. 176 for (Operation &op : llvm::make_early_inc_range( 177 llvm::make_range(entry.begin(), builder.getInsertionPoint()))) { 178 notifyRemoval(&op); 179 op.erase(); 180 } 181 return failure(); 182 } 183 184 // Process any newly generated operations. 185 if (processGeneratedConstants) { 186 for (auto i = entry.begin(), e = builder.getInsertionPoint(); i != e; ++i) 187 processGeneratedConstants(&*i); 188 } 189 190 return success(); 191 } 192 193 /// Try to get or create a new constant entry. On success this returns the 194 /// constant operation value, nullptr otherwise. 195 Operation *OperationFolder::tryGetOrCreateConstant(Dialect *dialect, 196 OpBuilder &builder, 197 Attribute value, Type type, 198 Location loc) { 199 // Check if an existing mapping already exists. 200 auto constKey = std::make_tuple(dialect, value, type); 201 auto *&constInst = uniquedConstants[constKey]; 202 if (constInst) 203 return constInst; 204 205 // If one doesn't exist, try to materialize one. 206 if (!(constInst = materializeConstant(dialect, builder, value, type, loc))) 207 return nullptr; 208 209 // Check to see if the generated constant is in the expected dialect. 210 auto *newDialect = constInst->getDialect(); 211 if (newDialect == dialect) { 212 referencedDialects[constInst].push_back(dialect); 213 return constInst; 214 } 215 216 // If it isn't, then we also need to make sure that the mapping for the new 217 // dialect is valid. 218 auto newKey = std::make_tuple(newDialect, value, type); 219 220 // If an existing operation in the new dialect already exists, delete the 221 // materialized operation in favor of the existing one. 222 if (auto *existingOp = uniquedConstants.lookup(newKey)) { 223 constInst->erase(); 224 referencedDialects[existingOp].push_back(dialect); 225 return constInst = existingOp; 226 } 227 228 // Otherwise, update the new dialect to the materialized operation. 229 referencedDialects[constInst].assign({dialect, newDialect}); 230 auto newIt = uniquedConstants.insert({newKey, constInst}); 231 return newIt.first->second; 232 } 233