1 //===- FoldUtils.cpp ---- Fold Utilities ----------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file defines various operation fold utilities. These utilities are 10 // intended to be used by passes to unify and simply their logic. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "mlir/Transforms/FoldUtils.h" 15 16 #include "mlir/Dialect/StandardOps/IR/Ops.h" 17 #include "mlir/IR/Builders.h" 18 #include "mlir/IR/Matchers.h" 19 #include "mlir/IR/Operation.h" 20 21 using namespace mlir; 22 23 /// Given an operation, find the parent region that folded constants should be 24 /// inserted into. 25 static Region *getInsertionRegion( 26 DialectInterfaceCollection<OpFolderDialectInterface> &interfaces, 27 Operation *op) { 28 while (Region *region = op->getParentRegion()) { 29 // Insert in this region for any of the following scenarios: 30 // * The parent is unregistered, or is known to be isolated from above. 31 // * The parent is a top-level operation. 32 auto *parentOp = region->getParentOp(); 33 if (!parentOp->isRegistered() || parentOp->isKnownIsolatedFromAbove() || 34 !parentOp->getBlock()) 35 return region; 36 37 // Otherwise, check if this region is a desired insertion region. 38 auto *interface = interfaces.getInterfaceFor(parentOp); 39 if (LLVM_UNLIKELY(interface && interface->shouldMaterializeInto(region))) 40 return region; 41 42 // Traverse up the parent looking for an insertion region. 43 op = parentOp; 44 } 45 llvm_unreachable("expected valid insertion region"); 46 } 47 48 /// A utility function used to materialize a constant for a given attribute and 49 /// type. On success, a valid constant value is returned. Otherwise, null is 50 /// returned 51 static Operation *materializeConstant(Dialect *dialect, OpBuilder &builder, 52 Attribute value, Type type, 53 Location loc) { 54 auto insertPt = builder.getInsertionPoint(); 55 (void)insertPt; 56 57 // Ask the dialect to materialize a constant operation for this value. 58 if (auto *constOp = dialect->materializeConstant(builder, value, type, loc)) { 59 assert(insertPt == builder.getInsertionPoint()); 60 assert(matchPattern(constOp, m_Constant())); 61 return constOp; 62 } 63 64 // If the dialect is unable to materialize a constant, check to see if the 65 // standard constant can be used. 66 if (ConstantOp::isBuildableWith(value, type)) 67 return builder.create<ConstantOp>(loc, type, value); 68 return nullptr; 69 } 70 71 //===----------------------------------------------------------------------===// 72 // OperationFolder 73 //===----------------------------------------------------------------------===// 74 75 LogicalResult OperationFolder::tryToFold( 76 Operation *op, function_ref<void(Operation *)> processGeneratedConstants, 77 function_ref<void(Operation *)> preReplaceAction) { 78 // If this is a unique'd constant, return failure as we know that it has 79 // already been folded. 80 if (referencedDialects.count(op)) 81 return failure(); 82 83 // Try to fold the operation. 84 SmallVector<Value, 8> results; 85 if (failed(tryToFold(op, results, processGeneratedConstants))) 86 return failure(); 87 88 // Constant folding succeeded. We will start replacing this op's uses and 89 // eventually erase this op. Invoke the callback provided by the caller to 90 // perform any pre-replacement action. 91 if (preReplaceAction) 92 preReplaceAction(op); 93 94 // Check to see if the operation was just updated in place. 95 if (results.empty()) 96 return success(); 97 98 // Otherwise, replace all of the result values and erase the operation. 99 for (unsigned i = 0, e = results.size(); i != e; ++i) 100 op->getResult(i).replaceAllUsesWith(results[i]); 101 op->erase(); 102 return success(); 103 } 104 105 /// Notifies that the given constant `op` should be remove from this 106 /// OperationFolder's internal bookkeeping. 107 void OperationFolder::notifyRemoval(Operation *op) { 108 // Check to see if this operation is uniqued within the folder. 109 auto it = referencedDialects.find(op); 110 if (it == referencedDialects.end()) 111 return; 112 113 // Get the constant value for this operation, this is the value that was used 114 // to unique the operation internally. 115 Attribute constValue; 116 matchPattern(op, m_Constant(&constValue)); 117 assert(constValue); 118 119 // Get the constant map that this operation was uniqued in. 120 auto &uniquedConstants = foldScopes[getInsertionRegion(interfaces, op)]; 121 122 // Erase all of the references to this operation. 123 auto type = op->getResult(0).getType(); 124 for (auto *dialect : it->second) 125 uniquedConstants.erase(std::make_tuple(dialect, constValue, type)); 126 referencedDialects.erase(it); 127 } 128 129 /// Clear out any constants cached inside of the folder. 130 void OperationFolder::clear() { 131 foldScopes.clear(); 132 referencedDialects.clear(); 133 } 134 135 /// Tries to perform folding on the given `op`. If successful, populates 136 /// `results` with the results of the folding. 137 LogicalResult OperationFolder::tryToFold( 138 Operation *op, SmallVectorImpl<Value> &results, 139 function_ref<void(Operation *)> processGeneratedConstants) { 140 SmallVector<Attribute, 8> operandConstants; 141 SmallVector<OpFoldResult, 8> foldResults; 142 143 // If this is a commutative operation, move constants to be trailing operands. 144 if (op->getNumOperands() >= 2 && op->isCommutative()) { 145 std::stable_partition( 146 op->getOpOperands().begin(), op->getOpOperands().end(), 147 [&](OpOperand &O) { return !matchPattern(O.get(), m_Constant()); }); 148 } 149 150 // Check to see if any operands to the operation is constant and whether 151 // the operation knows how to constant fold itself. 152 operandConstants.assign(op->getNumOperands(), Attribute()); 153 for (unsigned i = 0, e = op->getNumOperands(); i != e; ++i) 154 matchPattern(op->getOperand(i), m_Constant(&operandConstants[i])); 155 156 // Attempt to constant fold the operation. 157 if (failed(op->fold(operandConstants, foldResults))) 158 return failure(); 159 160 // Check to see if the operation was just updated in place. 161 if (foldResults.empty()) 162 return success(); 163 assert(foldResults.size() == op->getNumResults()); 164 165 // Create a builder to insert new operations into the entry block of the 166 // insertion region. 167 auto *insertRegion = getInsertionRegion(interfaces, op); 168 auto &entry = insertRegion->front(); 169 OpBuilder builder(&entry, entry.begin()); 170 171 // Get the constant map for the insertion region of this operation. 172 auto &uniquedConstants = foldScopes[insertRegion]; 173 174 // Create the result constants and replace the results. 175 auto *dialect = op->getDialect(); 176 for (unsigned i = 0, e = op->getNumResults(); i != e; ++i) { 177 assert(!foldResults[i].isNull() && "expected valid OpFoldResult"); 178 179 // Check if the result was an SSA value. 180 if (auto repl = foldResults[i].dyn_cast<Value>()) { 181 results.emplace_back(repl); 182 continue; 183 } 184 185 // Check to see if there is a canonicalized version of this constant. 186 auto res = op->getResult(i); 187 Attribute attrRepl = foldResults[i].get<Attribute>(); 188 if (auto *constOp = 189 tryGetOrCreateConstant(uniquedConstants, dialect, builder, attrRepl, 190 res.getType(), op->getLoc())) { 191 results.push_back(constOp->getResult(0)); 192 continue; 193 } 194 // If materialization fails, cleanup any operations generated for the 195 // previous results and return failure. 196 for (Operation &op : llvm::make_early_inc_range( 197 llvm::make_range(entry.begin(), builder.getInsertionPoint()))) { 198 notifyRemoval(&op); 199 op.erase(); 200 } 201 return failure(); 202 } 203 204 // Process any newly generated operations. 205 if (processGeneratedConstants) { 206 for (auto i = entry.begin(), e = builder.getInsertionPoint(); i != e; ++i) 207 processGeneratedConstants(&*i); 208 } 209 210 return success(); 211 } 212 213 /// Try to get or create a new constant entry. On success this returns the 214 /// constant operation value, nullptr otherwise. 215 Operation *OperationFolder::tryGetOrCreateConstant( 216 ConstantMap &uniquedConstants, Dialect *dialect, OpBuilder &builder, 217 Attribute value, Type type, Location loc) { 218 // Check if an existing mapping already exists. 219 auto constKey = std::make_tuple(dialect, value, type); 220 auto *&constInst = uniquedConstants[constKey]; 221 if (constInst) 222 return constInst; 223 224 // If one doesn't exist, try to materialize one. 225 if (!(constInst = materializeConstant(dialect, builder, value, type, loc))) 226 return nullptr; 227 228 // Check to see if the generated constant is in the expected dialect. 229 auto *newDialect = constInst->getDialect(); 230 if (newDialect == dialect) { 231 referencedDialects[constInst].push_back(dialect); 232 return constInst; 233 } 234 235 // If it isn't, then we also need to make sure that the mapping for the new 236 // dialect is valid. 237 auto newKey = std::make_tuple(newDialect, value, type); 238 239 // If an existing operation in the new dialect already exists, delete the 240 // materialized operation in favor of the existing one. 241 if (auto *existingOp = uniquedConstants.lookup(newKey)) { 242 constInst->erase(); 243 referencedDialects[existingOp].push_back(dialect); 244 return constInst = existingOp; 245 } 246 247 // Otherwise, update the new dialect to the materialized operation. 248 referencedDialects[constInst].assign({dialect, newDialect}); 249 auto newIt = uniquedConstants.insert({newKey, constInst}); 250 return newIt.first->second; 251 } 252