1 //===- FoldUtils.cpp ---- Fold Utilities ----------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file defines various operation fold utilities. These utilities are 10 // intended to be used by passes to unify and simply their logic. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "mlir/Transforms/FoldUtils.h" 15 16 #include "mlir/IR/Builders.h" 17 #include "mlir/IR/Matchers.h" 18 #include "mlir/IR/Operation.h" 19 20 using namespace mlir; 21 22 /// Given an operation, find the parent region that folded constants should be 23 /// inserted into. 24 static Region * 25 getInsertionRegion(DialectInterfaceCollection<DialectFoldInterface> &interfaces, 26 Block *insertionBlock) { 27 while (Region *region = insertionBlock->getParent()) { 28 // Insert in this region for any of the following scenarios: 29 // * The parent is unregistered, or is known to be isolated from above. 30 // * The parent is a top-level operation. 31 auto *parentOp = region->getParentOp(); 32 if (parentOp->mightHaveTrait<OpTrait::IsIsolatedFromAbove>() || 33 !parentOp->getBlock()) 34 return region; 35 36 // Otherwise, check if this region is a desired insertion region. 37 auto *interface = interfaces.getInterfaceFor(parentOp); 38 if (LLVM_UNLIKELY(interface && interface->shouldMaterializeInto(region))) 39 return region; 40 41 // Traverse up the parent looking for an insertion region. 42 insertionBlock = parentOp->getBlock(); 43 } 44 llvm_unreachable("expected valid insertion region"); 45 } 46 47 /// A utility function used to materialize a constant for a given attribute and 48 /// type. On success, a valid constant value is returned. Otherwise, null is 49 /// returned 50 static Operation *materializeConstant(Dialect *dialect, OpBuilder &builder, 51 Attribute value, Type type, 52 Location loc) { 53 auto insertPt = builder.getInsertionPoint(); 54 (void)insertPt; 55 56 // Ask the dialect to materialize a constant operation for this value. 57 if (auto *constOp = dialect->materializeConstant(builder, value, type, loc)) { 58 assert(insertPt == builder.getInsertionPoint()); 59 assert(matchPattern(constOp, m_Constant())); 60 return constOp; 61 } 62 63 return nullptr; 64 } 65 66 //===----------------------------------------------------------------------===// 67 // OperationFolder 68 //===----------------------------------------------------------------------===// 69 70 LogicalResult OperationFolder::tryToFold( 71 Operation *op, function_ref<void(Operation *)> processGeneratedConstants, 72 function_ref<void(Operation *)> preReplaceAction, bool *inPlaceUpdate) { 73 if (inPlaceUpdate) 74 *inPlaceUpdate = false; 75 76 // If this is a unique'd constant, return failure as we know that it has 77 // already been folded. 78 if (referencedDialects.count(op)) 79 return failure(); 80 81 // Try to fold the operation. 82 SmallVector<Value, 8> results; 83 OpBuilder builder(op); 84 if (failed(tryToFold(builder, op, results, processGeneratedConstants))) 85 return failure(); 86 87 // Check to see if the operation was just updated in place. 88 if (results.empty()) { 89 if (inPlaceUpdate) 90 *inPlaceUpdate = true; 91 return success(); 92 } 93 94 // Constant folding succeeded. We will start replacing this op's uses and 95 // erase this op. Invoke the callback provided by the caller to perform any 96 // pre-replacement action. 97 if (preReplaceAction) 98 preReplaceAction(op); 99 100 // Replace all of the result values and erase the operation. 101 for (unsigned i = 0, e = results.size(); i != e; ++i) 102 op->getResult(i).replaceAllUsesWith(results[i]); 103 op->erase(); 104 return success(); 105 } 106 107 /// Notifies that the given constant `op` should be remove from this 108 /// OperationFolder's internal bookkeeping. 109 void OperationFolder::notifyRemoval(Operation *op) { 110 // Check to see if this operation is uniqued within the folder. 111 auto it = referencedDialects.find(op); 112 if (it == referencedDialects.end()) 113 return; 114 115 // Get the constant value for this operation, this is the value that was used 116 // to unique the operation internally. 117 Attribute constValue; 118 matchPattern(op, m_Constant(&constValue)); 119 assert(constValue); 120 121 // Get the constant map that this operation was uniqued in. 122 auto &uniquedConstants = 123 foldScopes[getInsertionRegion(interfaces, op->getBlock())]; 124 125 // Erase all of the references to this operation. 126 auto type = op->getResult(0).getType(); 127 for (auto *dialect : it->second) 128 uniquedConstants.erase(std::make_tuple(dialect, constValue, type)); 129 referencedDialects.erase(it); 130 } 131 132 /// Clear out any constants cached inside of the folder. 133 void OperationFolder::clear() { 134 foldScopes.clear(); 135 referencedDialects.clear(); 136 } 137 138 /// Get or create a constant using the given builder. On success this returns 139 /// the constant operation, nullptr otherwise. 140 Value OperationFolder::getOrCreateConstant(OpBuilder &builder, Dialect *dialect, 141 Attribute value, Type type, 142 Location loc) { 143 OpBuilder::InsertionGuard foldGuard(builder); 144 145 // Use the builder insertion block to find an insertion point for the 146 // constant. 147 auto *insertRegion = 148 getInsertionRegion(interfaces, builder.getInsertionBlock()); 149 auto &entry = insertRegion->front(); 150 builder.setInsertionPoint(&entry, entry.begin()); 151 152 // Get the constant map for the insertion region of this operation. 153 auto &uniquedConstants = foldScopes[insertRegion]; 154 Operation *constOp = tryGetOrCreateConstant(uniquedConstants, dialect, 155 builder, value, type, loc); 156 return constOp ? constOp->getResult(0) : Value(); 157 } 158 159 /// Tries to perform folding on the given `op`. If successful, populates 160 /// `results` with the results of the folding. 161 LogicalResult OperationFolder::tryToFold( 162 OpBuilder &builder, Operation *op, SmallVectorImpl<Value> &results, 163 function_ref<void(Operation *)> processGeneratedConstants) { 164 SmallVector<Attribute, 8> operandConstants; 165 SmallVector<OpFoldResult, 8> foldResults; 166 167 // If this is a commutative operation, move constants to be trailing operands. 168 if (op->getNumOperands() >= 2 && op->hasTrait<OpTrait::IsCommutative>()) { 169 std::stable_partition( 170 op->getOpOperands().begin(), op->getOpOperands().end(), 171 [&](OpOperand &O) { return !matchPattern(O.get(), m_Constant()); }); 172 } 173 174 // Check to see if any operands to the operation is constant and whether 175 // the operation knows how to constant fold itself. 176 operandConstants.assign(op->getNumOperands(), Attribute()); 177 for (unsigned i = 0, e = op->getNumOperands(); i != e; ++i) 178 matchPattern(op->getOperand(i), m_Constant(&operandConstants[i])); 179 180 // Attempt to constant fold the operation. 181 if (failed(op->fold(operandConstants, foldResults))) 182 return failure(); 183 184 // Check to see if the operation was just updated in place. 185 if (foldResults.empty()) 186 return success(); 187 assert(foldResults.size() == op->getNumResults()); 188 189 // Create a builder to insert new operations into the entry block of the 190 // insertion region. 191 auto *insertRegion = 192 getInsertionRegion(interfaces, builder.getInsertionBlock()); 193 auto &entry = insertRegion->front(); 194 OpBuilder::InsertionGuard foldGuard(builder); 195 builder.setInsertionPoint(&entry, entry.begin()); 196 197 // Get the constant map for the insertion region of this operation. 198 auto &uniquedConstants = foldScopes[insertRegion]; 199 200 // Create the result constants and replace the results. 201 auto *dialect = op->getDialect(); 202 for (unsigned i = 0, e = op->getNumResults(); i != e; ++i) { 203 assert(!foldResults[i].isNull() && "expected valid OpFoldResult"); 204 205 // Check if the result was an SSA value. 206 if (auto repl = foldResults[i].dyn_cast<Value>()) { 207 if (repl.getType() != op->getResult(i).getType()) 208 return failure(); 209 results.emplace_back(repl); 210 continue; 211 } 212 213 // Check to see if there is a canonicalized version of this constant. 214 auto res = op->getResult(i); 215 Attribute attrRepl = foldResults[i].get<Attribute>(); 216 if (auto *constOp = 217 tryGetOrCreateConstant(uniquedConstants, dialect, builder, attrRepl, 218 res.getType(), op->getLoc())) { 219 // Ensure that this constant dominates the operation we are replacing it 220 // with. This may not automatically happen if the operation being folded 221 // was inserted before the constant within the insertion block. 222 Block *opBlock = op->getBlock(); 223 if (opBlock == constOp->getBlock() && &opBlock->front() != constOp) 224 constOp->moveBefore(&opBlock->front()); 225 226 results.push_back(constOp->getResult(0)); 227 continue; 228 } 229 // If materialization fails, cleanup any operations generated for the 230 // previous results and return failure. 231 for (Operation &op : llvm::make_early_inc_range( 232 llvm::make_range(entry.begin(), builder.getInsertionPoint()))) { 233 notifyRemoval(&op); 234 op.erase(); 235 } 236 return failure(); 237 } 238 239 // Process any newly generated operations. 240 if (processGeneratedConstants) { 241 for (auto i = entry.begin(), e = builder.getInsertionPoint(); i != e; ++i) 242 processGeneratedConstants(&*i); 243 } 244 245 return success(); 246 } 247 248 /// Try to get or create a new constant entry. On success this returns the 249 /// constant operation value, nullptr otherwise. 250 Operation *OperationFolder::tryGetOrCreateConstant( 251 ConstantMap &uniquedConstants, Dialect *dialect, OpBuilder &builder, 252 Attribute value, Type type, Location loc) { 253 // Check if an existing mapping already exists. 254 auto constKey = std::make_tuple(dialect, value, type); 255 Operation *&constOp = uniquedConstants[constKey]; 256 if (constOp) 257 return constOp; 258 259 // If one doesn't exist, try to materialize one. 260 if (!(constOp = materializeConstant(dialect, builder, value, type, loc))) 261 return nullptr; 262 263 // Check to see if the generated constant is in the expected dialect. 264 auto *newDialect = constOp->getDialect(); 265 if (newDialect == dialect) { 266 referencedDialects[constOp].push_back(dialect); 267 return constOp; 268 } 269 270 // If it isn't, then we also need to make sure that the mapping for the new 271 // dialect is valid. 272 auto newKey = std::make_tuple(newDialect, value, type); 273 274 // If an existing operation in the new dialect already exists, delete the 275 // materialized operation in favor of the existing one. 276 if (auto *existingOp = uniquedConstants.lookup(newKey)) { 277 constOp->erase(); 278 referencedDialects[existingOp].push_back(dialect); 279 return constOp = existingOp; 280 } 281 282 // Otherwise, update the new dialect to the materialized operation. 283 referencedDialects[constOp].assign({dialect, newDialect}); 284 auto newIt = uniquedConstants.insert({newKey, constOp}); 285 return newIt.first->second; 286 } 287