1 //===- FoldUtils.cpp ---- Fold Utilities ----------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file defines various operation fold utilities. These utilities are 10 // intended to be used by passes to unify and simply their logic. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "mlir/Transforms/FoldUtils.h" 15 16 #include "mlir/IR/Builders.h" 17 #include "mlir/IR/Matchers.h" 18 #include "mlir/IR/Operation.h" 19 20 using namespace mlir; 21 22 /// Given an operation, find the parent region that folded constants should be 23 /// inserted into. 24 static Region * 25 getInsertionRegion(DialectInterfaceCollection<DialectFoldInterface> &interfaces, 26 Block *insertionBlock) { 27 while (Region *region = insertionBlock->getParent()) { 28 // Insert in this region for any of the following scenarios: 29 // * The parent is unregistered, or is known to be isolated from above. 30 // * The parent is a top-level operation. 31 auto *parentOp = region->getParentOp(); 32 if (parentOp->mightHaveTrait<OpTrait::IsIsolatedFromAbove>() || 33 !parentOp->getBlock()) 34 return region; 35 36 // Otherwise, check if this region is a desired insertion region. 37 auto *interface = interfaces.getInterfaceFor(parentOp); 38 if (LLVM_UNLIKELY(interface && interface->shouldMaterializeInto(region))) 39 return region; 40 41 // Traverse up the parent looking for an insertion region. 42 insertionBlock = parentOp->getBlock(); 43 } 44 llvm_unreachable("expected valid insertion region"); 45 } 46 47 /// A utility function used to materialize a constant for a given attribute and 48 /// type. On success, a valid constant value is returned. Otherwise, null is 49 /// returned 50 static Operation *materializeConstant(Dialect *dialect, OpBuilder &builder, 51 Attribute value, Type type, 52 Location loc) { 53 auto insertPt = builder.getInsertionPoint(); 54 (void)insertPt; 55 56 // Ask the dialect to materialize a constant operation for this value. 57 if (auto *constOp = dialect->materializeConstant(builder, value, type, loc)) { 58 assert(insertPt == builder.getInsertionPoint()); 59 assert(matchPattern(constOp, m_Constant())); 60 return constOp; 61 } 62 63 return nullptr; 64 } 65 66 //===----------------------------------------------------------------------===// 67 // OperationFolder 68 //===----------------------------------------------------------------------===// 69 70 LogicalResult OperationFolder::tryToFold(Operation *op, bool *inPlaceUpdate) { 71 if (inPlaceUpdate) 72 *inPlaceUpdate = false; 73 74 // If this is a unique'd constant, return failure as we know that it has 75 // already been folded. 76 if (isFolderOwnedConstant(op)) { 77 // Check to see if we should rehoist, i.e. if a non-constant operation was 78 // inserted before this one. 79 Block *opBlock = op->getBlock(); 80 if (&opBlock->front() != op && !isFolderOwnedConstant(op->getPrevNode())) 81 op->moveBefore(&opBlock->front()); 82 return failure(); 83 } 84 85 // Try to fold the operation. 86 SmallVector<Value, 8> results; 87 if (failed(tryToFold(op, results))) 88 return failure(); 89 90 // Check to see if the operation was just updated in place. 91 if (results.empty()) { 92 if (inPlaceUpdate) 93 *inPlaceUpdate = true; 94 if (auto *rewriteListener = dyn_cast_if_present<RewriterBase::Listener>( 95 rewriter.getListener())) { 96 // Folding API does not notify listeners, so we have to notify manually. 97 rewriteListener->notifyOperationModified(op); 98 } 99 return success(); 100 } 101 102 // Constant folding succeeded. Replace all of the result values and erase the 103 // operation. 104 notifyRemoval(op); 105 rewriter.replaceOp(op, results); 106 return success(); 107 } 108 109 bool OperationFolder::insertKnownConstant(Operation *op, Attribute constValue) { 110 Block *opBlock = op->getBlock(); 111 112 // If this is a constant we unique'd, we don't need to insert, but we can 113 // check to see if we should rehoist it. 114 if (isFolderOwnedConstant(op)) { 115 if (&opBlock->front() != op && !isFolderOwnedConstant(op->getPrevNode())) 116 op->moveBefore(&opBlock->front()); 117 return true; 118 } 119 120 // Get the constant value of the op if necessary. 121 if (!constValue) { 122 matchPattern(op, m_Constant(&constValue)); 123 assert(constValue && "expected `op` to be a constant"); 124 } else { 125 // Ensure that the provided constant was actually correct. 126 #ifndef NDEBUG 127 Attribute expectedValue; 128 matchPattern(op, m_Constant(&expectedValue)); 129 assert( 130 expectedValue == constValue && 131 "provided constant value was not the expected value of the constant"); 132 #endif 133 } 134 135 // Check for an existing constant operation for the attribute value. 136 Region *insertRegion = getInsertionRegion(interfaces, opBlock); 137 auto &uniquedConstants = foldScopes[insertRegion]; 138 Operation *&folderConstOp = uniquedConstants[std::make_tuple( 139 op->getDialect(), constValue, *op->result_type_begin())]; 140 141 // If there is an existing constant, replace `op`. 142 if (folderConstOp) { 143 notifyRemoval(op); 144 rewriter.replaceOp(op, folderConstOp->getResults()); 145 return false; 146 } 147 148 // Otherwise, we insert `op`. If `op` is in the insertion block and is either 149 // already at the front of the block, or the previous operation is already a 150 // constant we unique'd (i.e. one we inserted), then we don't need to do 151 // anything. Otherwise, we move the constant to the insertion block. 152 Block *insertBlock = &insertRegion->front(); 153 if (opBlock != insertBlock || (&insertBlock->front() != op && 154 !isFolderOwnedConstant(op->getPrevNode()))) 155 op->moveBefore(&insertBlock->front()); 156 157 folderConstOp = op; 158 referencedDialects[op].push_back(op->getDialect()); 159 return true; 160 } 161 162 /// Notifies that the given constant `op` should be remove from this 163 /// OperationFolder's internal bookkeeping. 164 void OperationFolder::notifyRemoval(Operation *op) { 165 // Check to see if this operation is uniqued within the folder. 166 auto it = referencedDialects.find(op); 167 if (it == referencedDialects.end()) 168 return; 169 170 // Get the constant value for this operation, this is the value that was used 171 // to unique the operation internally. 172 Attribute constValue; 173 matchPattern(op, m_Constant(&constValue)); 174 assert(constValue); 175 176 // Get the constant map that this operation was uniqued in. 177 auto &uniquedConstants = 178 foldScopes[getInsertionRegion(interfaces, op->getBlock())]; 179 180 // Erase all of the references to this operation. 181 auto type = op->getResult(0).getType(); 182 for (auto *dialect : it->second) 183 uniquedConstants.erase(std::make_tuple(dialect, constValue, type)); 184 referencedDialects.erase(it); 185 } 186 187 /// Clear out any constants cached inside of the folder. 188 void OperationFolder::clear() { 189 foldScopes.clear(); 190 referencedDialects.clear(); 191 } 192 193 /// Get or create a constant using the given builder. On success this returns 194 /// the constant operation, nullptr otherwise. 195 Value OperationFolder::getOrCreateConstant(Block *block, Dialect *dialect, 196 Attribute value, Type type, 197 Location loc) { 198 // Find an insertion point for the constant. 199 auto *insertRegion = getInsertionRegion(interfaces, block); 200 auto &entry = insertRegion->front(); 201 rewriter.setInsertionPoint(&entry, entry.begin()); 202 203 // Get the constant map for the insertion region of this operation. 204 auto &uniquedConstants = foldScopes[insertRegion]; 205 Operation *constOp = 206 tryGetOrCreateConstant(uniquedConstants, dialect, value, type, loc); 207 return constOp ? constOp->getResult(0) : Value(); 208 } 209 210 bool OperationFolder::isFolderOwnedConstant(Operation *op) const { 211 return referencedDialects.count(op); 212 } 213 214 /// Tries to perform folding on the given `op`. If successful, populates 215 /// `results` with the results of the folding. 216 LogicalResult OperationFolder::tryToFold(Operation *op, 217 SmallVectorImpl<Value> &results) { 218 SmallVector<Attribute, 8> operandConstants; 219 220 // If this is a commutative operation, move constants to be trailing operands. 221 bool updatedOpOperands = false; 222 if (op->getNumOperands() >= 2 && op->hasTrait<OpTrait::IsCommutative>()) { 223 auto isNonConstant = [&](OpOperand &o) { 224 return !matchPattern(o.get(), m_Constant()); 225 }; 226 auto *firstConstantIt = 227 llvm::find_if_not(op->getOpOperands(), isNonConstant); 228 auto *newConstantIt = std::stable_partition( 229 firstConstantIt, op->getOpOperands().end(), isNonConstant); 230 231 // Remember if we actually moved anything. 232 updatedOpOperands = firstConstantIt != newConstantIt; 233 } 234 235 // Check to see if any operands to the operation is constant and whether 236 // the operation knows how to constant fold itself. 237 operandConstants.assign(op->getNumOperands(), Attribute()); 238 for (unsigned i = 0, e = op->getNumOperands(); i != e; ++i) 239 matchPattern(op->getOperand(i), m_Constant(&operandConstants[i])); 240 241 // Attempt to constant fold the operation. If we failed, check to see if we at 242 // least updated the operands of the operation. We treat this as an in-place 243 // fold. 244 SmallVector<OpFoldResult, 8> foldResults; 245 if (failed(op->fold(operandConstants, foldResults)) || 246 failed(processFoldResults(op, results, foldResults))) 247 return success(updatedOpOperands); 248 return success(); 249 } 250 251 LogicalResult 252 OperationFolder::processFoldResults(Operation *op, 253 SmallVectorImpl<Value> &results, 254 ArrayRef<OpFoldResult> foldResults) { 255 // Check to see if the operation was just updated in place. 256 if (foldResults.empty()) 257 return success(); 258 assert(foldResults.size() == op->getNumResults()); 259 260 // Create a builder to insert new operations into the entry block of the 261 // insertion region. 262 auto *insertRegion = getInsertionRegion(interfaces, op->getBlock()); 263 auto &entry = insertRegion->front(); 264 rewriter.setInsertionPoint(&entry, entry.begin()); 265 266 // Get the constant map for the insertion region of this operation. 267 auto &uniquedConstants = foldScopes[insertRegion]; 268 269 // Create the result constants and replace the results. 270 auto *dialect = op->getDialect(); 271 for (unsigned i = 0, e = op->getNumResults(); i != e; ++i) { 272 assert(!foldResults[i].isNull() && "expected valid OpFoldResult"); 273 274 // Check if the result was an SSA value. 275 if (auto repl = llvm::dyn_cast_if_present<Value>(foldResults[i])) { 276 if (repl.getType() != op->getResult(i).getType()) { 277 results.clear(); 278 return failure(); 279 } 280 results.emplace_back(repl); 281 continue; 282 } 283 284 // Check to see if there is a canonicalized version of this constant. 285 auto res = op->getResult(i); 286 Attribute attrRepl = foldResults[i].get<Attribute>(); 287 if (auto *constOp = tryGetOrCreateConstant( 288 uniquedConstants, dialect, attrRepl, res.getType(), op->getLoc())) { 289 // Ensure that this constant dominates the operation we are replacing it 290 // with. This may not automatically happen if the operation being folded 291 // was inserted before the constant within the insertion block. 292 Block *opBlock = op->getBlock(); 293 if (opBlock == constOp->getBlock() && &opBlock->front() != constOp) 294 constOp->moveBefore(&opBlock->front()); 295 296 results.push_back(constOp->getResult(0)); 297 continue; 298 } 299 // If materialization fails, cleanup any operations generated for the 300 // previous results and return failure. 301 for (Operation &op : llvm::make_early_inc_range( 302 llvm::make_range(entry.begin(), rewriter.getInsertionPoint()))) { 303 notifyRemoval(&op); 304 rewriter.eraseOp(&op); 305 } 306 307 results.clear(); 308 return failure(); 309 } 310 311 return success(); 312 } 313 314 /// Try to get or create a new constant entry. On success this returns the 315 /// constant operation value, nullptr otherwise. 316 Operation * 317 OperationFolder::tryGetOrCreateConstant(ConstantMap &uniquedConstants, 318 Dialect *dialect, Attribute value, 319 Type type, Location loc) { 320 // Check if an existing mapping already exists. 321 auto constKey = std::make_tuple(dialect, value, type); 322 Operation *&constOp = uniquedConstants[constKey]; 323 if (constOp) 324 return constOp; 325 326 // If one doesn't exist, try to materialize one. 327 if (!(constOp = materializeConstant(dialect, rewriter, value, type, loc))) 328 return nullptr; 329 330 // Check to see if the generated constant is in the expected dialect. 331 auto *newDialect = constOp->getDialect(); 332 if (newDialect == dialect) { 333 referencedDialects[constOp].push_back(dialect); 334 return constOp; 335 } 336 337 // If it isn't, then we also need to make sure that the mapping for the new 338 // dialect is valid. 339 auto newKey = std::make_tuple(newDialect, value, type); 340 341 // If an existing operation in the new dialect already exists, delete the 342 // materialized operation in favor of the existing one. 343 if (auto *existingOp = uniquedConstants.lookup(newKey)) { 344 notifyRemoval(constOp); 345 rewriter.eraseOp(constOp); 346 referencedDialects[existingOp].push_back(dialect); 347 return constOp = existingOp; 348 } 349 350 // Otherwise, update the new dialect to the materialized operation. 351 referencedDialects[constOp].assign({dialect, newDialect}); 352 auto newIt = uniquedConstants.insert({newKey, constOp}); 353 return newIt.first->second; 354 } 355