//===- Utils.cpp ---- Misc utilities for loop transformation ----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements miscellaneous loop transformation routines.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/SCF/Utils/Utils.h"
#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Transforms/RegionUtils.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/MathExtras.h"
#include <cstdint>

using namespace mlir;

#define DEBUG_TYPE "scf-utils"
#define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")
#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")

SmallVector<scf::ForOp> mlir::replaceLoopNestWithNewYields(
    RewriterBase &rewriter, MutableArrayRef<scf::ForOp> loopNest,
    ValueRange newIterOperands, const NewYieldValuesFn &newYieldValuesFn,
    bool replaceIterOperandsUsesInLoop) {
  if (loopNest.empty())
    return {};
  // This method is recursive (to make it more readable). Adding an
  // assertion here to limit the recursion. (See
  // https://discourse.llvm.org/t/rfc-update-to-mlir-developer-policy-on-recursion/62235)
  assert(loopNest.size() <= 10 &&
         "exceeded recursion limit when yielding value from loop nest");

  // To yield a value from a perfectly nested loop nest, the following
  // pattern needs to be created, i.e. starting with
  //
  // ```mlir
  //  scf.for .. {
  //    scf.for .. {
  //      scf.for .. {
  //        %value = ...
  //      }
  //    }
  //  }
  // ```
  //
  // needs to be modified to
  //
  // ```mlir
  //  %0 = scf.for .. iter_args(%arg0 = %init) {
  //    %1 = scf.for .. iter_args(%arg1 = %arg0) {
  //      %2 = scf.for .. iter_args(%arg2 = %arg1) {
  //        %value = ...
  //        scf.yield %value
  //      }
  //      scf.yield %2
  //    }
  //    scf.yield %1
  //  }
  // ```
  //
  // The innermost loop is handled using `replaceWithAdditionalYields`,
  // which works on a single loop.
  if (loopNest.size() == 1) {
    auto innerMostLoop =
        cast<scf::ForOp>(*loopNest.back().replaceWithAdditionalYields(
            rewriter, newIterOperands, replaceIterOperandsUsesInLoop,
            newYieldValuesFn));
    return {innerMostLoop};
  }
  // The outer loops are modified by calling this method recursively
  // - The return value of the inner loop is the value yielded by this loop.
  // - The region iter args of this loop are the init_args for the inner loop.
  SmallVector<scf::ForOp> newLoopNest;
  NewYieldValuesFn fn =
      [&](OpBuilder &innerBuilder, Location loc,
          ArrayRef<BlockArgument> innerNewBBArgs) -> SmallVector<Value> {
    newLoopNest = replaceLoopNestWithNewYields(rewriter, loopNest.drop_front(),
                                               innerNewBBArgs, newYieldValuesFn,
                                               replaceIterOperandsUsesInLoop);
    return llvm::to_vector(llvm::map_range(
        newLoopNest.front().getResults().take_back(innerNewBBArgs.size()),
        [](OpResult r) -> Value { return r; }));
  };
  scf::ForOp outerMostLoop =
      cast<scf::ForOp>(*loopNest.front().replaceWithAdditionalYields(
          rewriter, newIterOperands, replaceIterOperandsUsesInLoop, fn));
  newLoopNest.insert(newLoopNest.begin(), outerMostLoop);
  return newLoopNest;
}

/// Outline a region with a single block into a new FuncOp.
/// Assumes the FuncOp result types are the types of the operands yielded by
/// the single block. This constraint makes it easy to determine the result
/// types. This method also clones any `arith::ConstantIndexOp` at the start of
/// `outlinedFuncBody` to allow simple canonicalizations. If `callOp` is
/// provided, it will be set to point to the operation that calls the outlined
/// function.
// TODO: support more than single-block regions.
// TODO: more flexible constant handling.
FailureOr<func::FuncOp> mlir::outlineSingleBlockRegion(RewriterBase &rewriter,
                                                       Location loc,
                                                       Region &region,
                                                       StringRef funcName,
                                                       func::CallOp *callOp) {
  assert(!funcName.empty() && "funcName cannot be empty");
  if (!region.hasOneBlock())
    return failure();

  Block *originalBlock = &region.front();
  Operation *originalTerminator = originalBlock->getTerminator();

  // Outline before current function.
  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPoint(region.getParentOfType<FunctionOpInterface>());

  SetVector<Value> captures;
  getUsedValuesDefinedAbove(region, captures);

  ValueRange outlinedValues(captures.getArrayRef());
  SmallVector<Type> outlinedFuncArgTypes;
  SmallVector<Location> outlinedFuncArgLocs;
  // Region's arguments are exactly the first block's arguments as per
  // Region::getArguments().
  // Func's arguments are cat(region's arguments, captured values).
  for (BlockArgument arg : region.getArguments()) {
    outlinedFuncArgTypes.push_back(arg.getType());
    outlinedFuncArgLocs.push_back(arg.getLoc());
  }
  for (Value value : outlinedValues) {
    outlinedFuncArgTypes.push_back(value.getType());
    outlinedFuncArgLocs.push_back(value.getLoc());
  }
  FunctionType outlinedFuncType =
      FunctionType::get(rewriter.getContext(), outlinedFuncArgTypes,
                        originalTerminator->getOperandTypes());
  auto outlinedFunc =
      rewriter.create<func::FuncOp>(loc, funcName, outlinedFuncType);
  Block *outlinedFuncBody = outlinedFunc.addEntryBlock();

  // Merge blocks while replacing the original block operands.
  // Warning: `mergeBlocks` erases the original block, reconstruct it later.
  int64_t numOriginalBlockArguments = originalBlock->getNumArguments();
  auto outlinedFuncBlockArgs = outlinedFuncBody->getArguments();
  {
    OpBuilder::InsertionGuard g(rewriter);
    rewriter.setInsertionPointToEnd(outlinedFuncBody);
    rewriter.mergeBlocks(
        originalBlock, outlinedFuncBody,
        outlinedFuncBlockArgs.take_front(numOriginalBlockArguments));
    // Explicitly set up a new ReturnOp terminator.
    rewriter.setInsertionPointToEnd(outlinedFuncBody);
    rewriter.create<func::ReturnOp>(loc, originalTerminator->getResultTypes(),
                                    originalTerminator->getOperands());
  }

  // Reconstruct the block that was deleted and add a
  // terminator(call_results).
  Block *newBlock = rewriter.createBlock(
      &region, region.begin(),
      TypeRange{outlinedFuncArgTypes}.take_front(numOriginalBlockArguments),
      ArrayRef<Location>(outlinedFuncArgLocs)
          .take_front(numOriginalBlockArguments));
  {
    OpBuilder::InsertionGuard g(rewriter);
    rewriter.setInsertionPointToEnd(newBlock);
    SmallVector<Value> callValues;
    llvm::append_range(callValues, newBlock->getArguments());
    llvm::append_range(callValues, outlinedValues);
    auto call = rewriter.create<func::CallOp>(loc, outlinedFunc, callValues);
    if (callOp)
      *callOp = call;

    // `originalTerminator` was moved to `outlinedFuncBody` and is still valid.
    // Clone `originalTerminator` to take the callOp results then erase it from
    // `outlinedFuncBody`.
    IRMapping bvm;
    bvm.map(originalTerminator->getOperands(), call->getResults());
    rewriter.clone(*originalTerminator, bvm);
    rewriter.eraseOp(originalTerminator);
  }

  // Lastly, explicit RAUW outlinedValues, only for uses within `outlinedFunc`.
  // Clone the `arith::ConstantIndexOp` at the start of `outlinedFuncBody`.
  for (auto it : llvm::zip(outlinedValues, outlinedFuncBlockArgs.take_back(
                                               outlinedValues.size()))) {
    Value orig = std::get<0>(it);
    Value repl = std::get<1>(it);
    {
      OpBuilder::InsertionGuard g(rewriter);
      rewriter.setInsertionPointToStart(outlinedFuncBody);
      if (Operation *cst = orig.getDefiningOp<arith::ConstantIndexOp>()) {
        IRMapping bvm;
        repl = rewriter.clone(*cst, bvm)->getResult(0);
      }
    }
    orig.replaceUsesWithIf(repl, [&](OpOperand &opOperand) {
      return outlinedFunc->isProperAncestor(opOperand.getOwner());
    });
  }

  return outlinedFunc;
}

LogicalResult mlir::outlineIfOp(RewriterBase &b, scf::IfOp ifOp,
                                func::FuncOp *thenFn, StringRef thenFnName,
                                func::FuncOp *elseFn, StringRef elseFnName) {
  IRRewriter rewriter(b);
  Location loc = ifOp.getLoc();
  FailureOr<func::FuncOp> outlinedFuncOpOrFailure;
  if (thenFn && !ifOp.getThenRegion().empty()) {
    outlinedFuncOpOrFailure = outlineSingleBlockRegion(
        rewriter, loc, ifOp.getThenRegion(), thenFnName);
    if (failed(outlinedFuncOpOrFailure))
      return failure();
    *thenFn = *outlinedFuncOpOrFailure;
  }
  if (elseFn && !ifOp.getElseRegion().empty()) {
    outlinedFuncOpOrFailure = outlineSingleBlockRegion(
        rewriter, loc, ifOp.getElseRegion(), elseFnName);
    if (failed(outlinedFuncOpOrFailure))
      return failure();
    *elseFn = *outlinedFuncOpOrFailure;
  }
  return success();
}
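
// As an illustration of `outlineIfOp` (a sketch, not taken from a test; the
// function names and ops are made up), given:
//
// ```mlir
//   scf.if %cond {
//     "some.op"() : () -> ()
//   } else {
//     "other.op"() : () -> ()
//   }
// ```
//
// a successful outlining with thenFnName = "then_fn" and elseFnName =
// "else_fn" leaves behind calls to two newly created functions:
//
// ```mlir
//   scf.if %cond {
//     func.call @then_fn() : () -> ()
//   } else {
//     func.call @else_fn() : () -> ()
//   }
// ```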

bool mlir::getInnermostParallelLoops(Operation *rootOp,
                                     SmallVectorImpl<scf::ParallelOp> &result) {
  assert(rootOp != nullptr && "Root operation must not be a nullptr.");
  bool rootEnclosesPloops = false;
  for (Region &region : rootOp->getRegions()) {
    for (Block &block : region.getBlocks()) {
      for (Operation &op : block) {
        bool enclosesPloops = getInnermostParallelLoops(&op, result);
        rootEnclosesPloops |= enclosesPloops;
        if (auto ploop = dyn_cast<scf::ParallelOp>(op)) {
          rootEnclosesPloops = true;

          // Collect parallel loop if it is an innermost one.
          if (!enclosesPloops)
            result.push_back(ploop);
        }
      }
    }
  }
  return rootEnclosesPloops;
}

// Build the IR that performs ceil division of a positive value by a constant:
//    ceildiv(a, B) = divis(a + (B - 1), B)
// where divis is rounding-to-zero division.
static Value ceilDivPositive(OpBuilder &builder, Location loc, Value dividend,
                             int64_t divisor) {
  assert(divisor > 0 && "expected positive divisor");
  assert(dividend.getType().isIntOrIndex() &&
         "expected integer or index-typed value");

  Value divisorMinusOneCst = builder.create<arith::ConstantOp>(
      loc, builder.getIntegerAttr(dividend.getType(), divisor - 1));
  Value divisorCst = builder.create<arith::ConstantOp>(
      loc, builder.getIntegerAttr(dividend.getType(), divisor));
  Value sum = builder.create<arith::AddIOp>(loc, dividend, divisorMinusOneCst);
  return builder.create<arith::DivUIOp>(loc, sum, divisorCst);
}

// Build the IR that performs ceil division of a positive value by another
// positive value:
//    ceildiv(a, b) = divis(a + (b - 1), b)
// where divis is rounding-to-zero division.
static Value ceilDivPositive(OpBuilder &builder, Location loc, Value dividend,
                             Value divisor) {
  assert(dividend.getType().isIntOrIndex() &&
         "expected integer or index-typed value");
  Value cstOne = builder.create<arith::ConstantOp>(
      loc, builder.getOneAttr(dividend.getType()));
  Value divisorMinusOne = builder.create<arith::SubIOp>(loc, divisor, cstOne);
  Value sum = builder.create<arith::AddIOp>(loc, dividend, divisorMinusOne);
  return builder.create<arith::DivUIOp>(loc, sum, divisorCst);
}
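
// Worked example for the two helpers above: with a = 10 and B = 4, the emitted
// computation is (10 + (4 - 1)) / 4 = 13 / 4 = 3 = ceildiv(10, 4), using
// round-toward-zero division on non-negative operands.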

/// Returns the trip count of `forOp` if its lower bound, upper bound and step
/// are constants, or std::nullopt otherwise. The trip count is computed as
/// ceilDiv(upperBound - lowerBound, step).
static std::optional<int64_t> getConstantTripCount(scf::ForOp forOp) {
  std::optional<int64_t> lbCstOp = getConstantIntValue(forOp.getLowerBound());
  std::optional<int64_t> ubCstOp = getConstantIntValue(forOp.getUpperBound());
  std::optional<int64_t> stepCstOp = getConstantIntValue(forOp.getStep());
  if (!lbCstOp.has_value() || !ubCstOp.has_value() || !stepCstOp.has_value())
    return {};

  // Constant loop bounds computation.
  int64_t lbCst = lbCstOp.value();
  int64_t ubCst = ubCstOp.value();
  int64_t stepCst = stepCstOp.value();
  assert(lbCst >= 0 && ubCst >= 0 && stepCst > 0 &&
         "expected non-negative loop bounds and positive step");
  return llvm::divideCeilSigned(ubCst - lbCst, stepCst);
}

/// Generates unrolled copies of the body of scf::ForOp 'loopBodyBlock', with
/// associated induction variable 'forOpIV', by 'unrollFactor', calling
/// 'ivRemapFn' to remap 'forOpIV' for each unrolled body. If specified,
/// annotates the ops in each unrolled iteration using 'annotateFn'.
static void generateUnrolledLoop(
    Block *loopBodyBlock, Value forOpIV, uint64_t unrollFactor,
    function_ref<Value(unsigned, Value, OpBuilder)> ivRemapFn,
    function_ref<void(unsigned, Operation *, OpBuilder)> annotateFn,
    ValueRange iterArgs, ValueRange yieldedValues) {
  // Builder to insert unrolled bodies just before the terminator of the body
  // of 'forOp'.
  auto builder = OpBuilder::atBlockTerminator(loopBodyBlock);

  constexpr auto defaultAnnotateFn = [](unsigned, Operation *, OpBuilder) {};
  if (!annotateFn)
    annotateFn = defaultAnnotateFn;

  // Keep a pointer to the last non-terminator operation in the original block
  // so that we know what to clone (since we are doing this in-place).
  Block::iterator srcBlockEnd = std::prev(loopBodyBlock->end(), 2);

  // Unroll the contents of 'forOp' (append unrollFactor - 1 additional
  // copies).
  SmallVector<Value, 4> lastYielded(yieldedValues);

  for (unsigned i = 1; i < unrollFactor; i++) {
    IRMapping operandMap;

    // Prepare operand map.
    operandMap.map(iterArgs, lastYielded);

    // If the induction variable is used, create a remapping to the value for
    // this unrolled instance.
    if (!forOpIV.use_empty()) {
      Value ivUnroll = ivRemapFn(i, forOpIV, builder);
      operandMap.map(forOpIV, ivUnroll);
    }

    // Clone the original body of 'forOp'.
    for (auto it = loopBodyBlock->begin(); it != std::next(srcBlockEnd);
         it++) {
      Operation *clonedOp = builder.clone(*it, operandMap);
      annotateFn(i, clonedOp, builder);
    }

    // Update yielded values.
    for (unsigned i = 0, e = lastYielded.size(); i < e; i++)
      lastYielded[i] = operandMap.lookupOrDefault(yieldedValues[i]);
  }

  // Make sure we annotate the Ops in the original body. We do this last so
  // that any annotations are not copied into the cloned Ops above.
  for (auto it = loopBodyBlock->begin(); it != std::next(srcBlockEnd); it++)
    annotateFn(0, &*it, builder);

  // Update operands of the yield statement.
  loopBodyBlock->getTerminator()->setOperands(lastYielded);
}

/// Unrolls 'forOp' by 'unrollFactor', returning the unrolled main loop and the
/// epilogue loop, if the loop is unrolled.
FailureOr<UnrolledLoopInfo> mlir::loopUnrollByFactor(
    scf::ForOp forOp, uint64_t unrollFactor,
    function_ref<void(unsigned, Operation *, OpBuilder)> annotateFn) {
  assert(unrollFactor > 0 && "expected positive unroll factor");

  // Return if the loop body is empty.
  if (llvm::hasSingleElement(forOp.getBody()->getOperations()))
    return UnrolledLoopInfo{forOp, std::nullopt};

  // Compute tripCount = ceilDiv((upperBound - lowerBound), step) and populate
  // 'upperBoundUnrolled' and 'stepUnrolled' for static and dynamic cases.
  OpBuilder boundsBuilder(forOp);
  IRRewriter rewriter(forOp.getContext());
  auto loc = forOp.getLoc();
  Value step = forOp.getStep();
  Value upperBoundUnrolled;
  Value stepUnrolled;
  bool generateEpilogueLoop = true;

  std::optional<int64_t> constTripCount = getConstantTripCount(forOp);
  if (constTripCount) {
    // Constant loop bounds computation.
    int64_t lbCst = getConstantIntValue(forOp.getLowerBound()).value();
    int64_t ubCst = getConstantIntValue(forOp.getUpperBound()).value();
    int64_t stepCst = getConstantIntValue(forOp.getStep()).value();
    if (unrollFactor == 1) {
      if (*constTripCount == 1 &&
          failed(forOp.promoteIfSingleIteration(rewriter)))
        return failure();
      return UnrolledLoopInfo{forOp, std::nullopt};
    }

    int64_t tripCountEvenMultiple =
        *constTripCount - (*constTripCount % unrollFactor);
    int64_t upperBoundUnrolledCst = lbCst + tripCountEvenMultiple * stepCst;
    int64_t stepUnrolledCst = stepCst * unrollFactor;

    // Create constant for 'upperBoundUnrolled' and set epilogue loop flag.
    generateEpilogueLoop = upperBoundUnrolledCst < ubCst;
    if (generateEpilogueLoop)
      upperBoundUnrolled = boundsBuilder.create<arith::ConstantOp>(
          loc, boundsBuilder.getIntegerAttr(forOp.getUpperBound().getType(),
                                            upperBoundUnrolledCst));
    else
      upperBoundUnrolled = forOp.getUpperBound();

    // Create constant for 'stepUnrolled'.
    stepUnrolled = stepCst == stepUnrolledCst
                       ? step
                       : boundsBuilder.create<arith::ConstantOp>(
                             loc, boundsBuilder.getIntegerAttr(
                                      step.getType(), stepUnrolledCst));
  } else {
    // Dynamic loop bounds computation.
    // TODO: Add dynamic asserts for negative lb/ub/step, or
    // consider using ceilDiv from AffineApplyExpander.
    auto lowerBound = forOp.getLowerBound();
    auto upperBound = forOp.getUpperBound();
    Value diff =
        boundsBuilder.create<arith::SubIOp>(loc, upperBound, lowerBound);
    Value tripCount = ceilDivPositive(boundsBuilder, loc, diff, step);
    Value unrollFactorCst = boundsBuilder.create<arith::ConstantOp>(
        loc, boundsBuilder.getIntegerAttr(tripCount.getType(), unrollFactor));
    Value tripCountRem =
        boundsBuilder.create<arith::RemSIOp>(loc, tripCount, unrollFactorCst);
    // Compute tripCountEvenMultiple = tripCount - (tripCount % unrollFactor).
    Value tripCountEvenMultiple =
        boundsBuilder.create<arith::SubIOp>(loc, tripCount, tripCountRem);
    // Compute upperBoundUnrolled = lowerBound + tripCountEvenMultiple * step.
    upperBoundUnrolled = boundsBuilder.create<arith::AddIOp>(
        loc, lowerBound,
        boundsBuilder.create<arith::MulIOp>(loc, tripCountEvenMultiple, step));
    // Scale 'step' by 'unrollFactor'.
    stepUnrolled =
        boundsBuilder.create<arith::MulIOp>(loc, step, unrollFactorCst);
  }

  UnrolledLoopInfo resultLoops;

  // Create epilogue clean-up loop starting at 'upperBoundUnrolled'.
  if (generateEpilogueLoop) {
    OpBuilder epilogueBuilder(forOp->getContext());
    epilogueBuilder.setInsertionPointAfter(forOp);
    auto epilogueForOp = cast<scf::ForOp>(epilogueBuilder.clone(*forOp));
    epilogueForOp.setLowerBound(upperBoundUnrolled);

    // Update uses of loop results.
    auto results = forOp.getResults();
    auto epilogueResults = epilogueForOp.getResults();

    for (auto e : llvm::zip(results, epilogueResults)) {
      std::get<0>(e).replaceAllUsesWith(std::get<1>(e));
    }
    epilogueForOp->setOperands(epilogueForOp.getNumControlOperands(),
                               epilogueForOp.getInitArgs().size(), results);
    if (epilogueForOp.promoteIfSingleIteration(rewriter).failed())
      resultLoops.epilogueLoopOp = epilogueForOp;
  }

  // Create unrolled loop.
  forOp.setUpperBound(upperBoundUnrolled);
  forOp.setStep(stepUnrolled);

  auto iterArgs = ValueRange(forOp.getRegionIterArgs());
  auto yieldedValues = forOp.getBody()->getTerminator()->getOperands();

  generateUnrolledLoop(
      forOp.getBody(), forOp.getInductionVar(), unrollFactor,
      [&](unsigned i, Value iv, OpBuilder b) {
        // iv' = iv + step * i;
        auto stride = b.create<arith::MulIOp>(
            loc, step,
            b.create<arith::ConstantOp>(loc,
                                        b.getIntegerAttr(iv.getType(), i)));
        return b.create<arith::AddIOp>(loc, iv, stride);
      },
      annotateFn, iterArgs, yieldedValues);
  // Promote the loop body up if this has turned into a single iteration loop.
  if (forOp.promoteIfSingleIteration(rewriter).failed())
    resultLoops.mainLoopOp = forOp;
  return resultLoops;
}
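
// A sketch of the effect of `loopUnrollByFactor` (illustrative; constants are
// made up and the muli/addi of the remapped induction variable is shown
// folded). Unrolling the following loop by a factor of 3:
//
// ```mlir
//   scf.for %iv = %c0 to %c8 step %c1 {
//     "some.op"(%iv) : (index) -> ()
//   }
// ```
//
// yields a main loop covering the even multiple of the trip count (6 of 8
// iterations) plus an epilogue loop for the remainder:
//
// ```mlir
//   scf.for %iv = %c0 to %c6 step %c3 {
//     "some.op"(%iv) : (index) -> ()
//     %iv.1 = arith.addi %iv, %c1 : index
//     "some.op"(%iv.1) : (index) -> ()
//     %iv.2 = arith.addi %iv, %c2 : index
//     "some.op"(%iv.2) : (index) -> ()
//   }
//   scf.for %iv = %c6 to %c8 step %c1 {
//     "some.op"(%iv) : (index) -> ()
//   }
// ```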

/// Check if bounds of all inner loops are defined outside of `forOp`
/// and return false if not.
static bool areInnerBoundsInvariant(scf::ForOp forOp) {
  auto walkResult = forOp.walk([&](scf::ForOp innerForOp) {
    if (!forOp.isDefinedOutsideOfLoop(innerForOp.getLowerBound()) ||
        !forOp.isDefinedOutsideOfLoop(innerForOp.getUpperBound()) ||
        !forOp.isDefinedOutsideOfLoop(innerForOp.getStep()))
      return WalkResult::interrupt();

    return WalkResult::advance();
  });
  return !walkResult.wasInterrupted();
}

/// Unrolls and jams this loop by the specified factor.
LogicalResult mlir::loopUnrollJamByFactor(scf::ForOp forOp,
                                          uint64_t unrollJamFactor) {
  assert(unrollJamFactor > 0 && "unroll jam factor should be positive");

  if (unrollJamFactor == 1)
    return success();

  // If any control operand of any inner loop of `forOp` is defined within
  // `forOp`, no unroll jam.
  if (!areInnerBoundsInvariant(forOp)) {
    LDBG("failed to unroll and jam: inner bounds are not invariant");
    return failure();
  }

  // Currently, `for` ops with results are not supported.
  if (forOp->getNumResults() > 0) {
    LDBG("failed to unroll and jam: unsupported loop with results");
    return failure();
  }

  // Currently, only loops with a constant trip count that is a multiple of
  // the unroll jam factor are supported.
  std::optional<uint64_t> tripCount = getConstantTripCount(forOp);
  if (!tripCount.has_value()) {
    // If the trip count is dynamic, do not unroll & jam.
    LDBG("failed to unroll and jam: trip count could not be determined");
    return failure();
  }
  if (unrollJamFactor > *tripCount) {
    LDBG("unroll and jam factor is greater than trip count, set factor to trip "
         "count");
    unrollJamFactor = *tripCount;
  } else if (*tripCount % unrollJamFactor != 0) {
    LDBG("failed to unroll and jam: unsupported trip count that is not a "
         "multiple of unroll jam factor");
    return failure();
  }

  // Nothing in the loop body other than the terminator.
  if (llvm::hasSingleElement(forOp.getBody()->getOperations()))
    return success();

  // Gather all sub-blocks to jam upon the loop being unrolled.
  JamBlockGatherer<scf::ForOp> jbg;
  jbg.walk(forOp);
  auto &subBlocks = jbg.subBlocks;

  // Collect inner loops.
  SmallVector<scf::ForOp> innerLoops;
  forOp.walk([&](scf::ForOp innerForOp) { innerLoops.push_back(innerForOp); });

  // `operandMaps[i - 1]` carries old->new operand mapping for the ith unrolled
  // iteration. There are (`unrollJamFactor` - 1) iterations.
  SmallVector<IRMapping> operandMaps(unrollJamFactor - 1);

  // For any loop with iter_args, replace it with a new loop that has
  // `unrollJamFactor` copies of its iterOperands, iter_args and yield
  // operands.
  SmallVector<scf::ForOp> newInnerLoops;
  IRRewriter rewriter(forOp.getContext());
  for (scf::ForOp oldForOp : innerLoops) {
    SmallVector<Value> dupIterOperands, dupYieldOperands;
    ValueRange oldIterOperands = oldForOp.getInits();
    ValueRange oldIterArgs = oldForOp.getRegionIterArgs();
    ValueRange oldYieldOperands =
        cast<scf::YieldOp>(oldForOp.getBody()->getTerminator()).getOperands();
    // Get additional iterOperands, iterArgs, and yield operands. We will
    // fix iterOperands and yield operands after cloning of sub-blocks.
    for (unsigned i = unrollJamFactor - 1; i >= 1; --i) {
      dupIterOperands.append(oldIterOperands.begin(), oldIterOperands.end());
      dupYieldOperands.append(oldYieldOperands.begin(), oldYieldOperands.end());
    }
    // Create a new loop with additional iterOperands, iter_args and yield
    // operands. This new loop will take the loop body of the original loop.
    bool forOpReplaced = oldForOp == forOp;
    scf::ForOp newForOp =
        cast<scf::ForOp>(*oldForOp.replaceWithAdditionalYields(
            rewriter, dupIterOperands, /*replaceInitOperandUsesInLoop=*/false,
            [&](OpBuilder &b, Location loc, ArrayRef<BlockArgument> newBbArgs) {
              return dupYieldOperands;
            }));
    newInnerLoops.push_back(newForOp);
    // `forOp` has been replaced with a new loop.
    if (forOpReplaced)
      forOp = newForOp;
    // Update `operandMaps` for `newForOp` iterArgs and results.
    ValueRange newIterArgs = newForOp.getRegionIterArgs();
    unsigned oldNumIterArgs = oldIterArgs.size();
    ValueRange newResults = newForOp.getResults();
    unsigned oldNumResults = newResults.size() / unrollJamFactor;
    assert(oldNumIterArgs == oldNumResults &&
           "oldNumIterArgs must be the same as oldNumResults");
    for (unsigned i = unrollJamFactor - 1; i >= 1; --i) {
      for (unsigned j = 0; j < oldNumIterArgs; ++j) {
        // `newForOp` has `unrollJamFactor` - 1 new sets of iterArgs and
        // results. Update `operandMaps[i - 1]` to map old iterArgs and results
        // to those in the `i`th new set.
        operandMaps[i - 1].map(newIterArgs[j],
                               newIterArgs[i * oldNumIterArgs + j]);
        operandMaps[i - 1].map(newResults[j],
                               newResults[i * oldNumResults + j]);
      }
    }
  }

  // Scale the step of the loop being unroll-jammed by the unroll-jam factor.
  rewriter.setInsertionPoint(forOp);
  int64_t step = forOp.getConstantStep()->getSExtValue();
  auto newStep = rewriter.createOrFold<arith::MulIOp>(
      forOp.getLoc(), forOp.getStep(),
      rewriter.createOrFold<arith::ConstantOp>(
          forOp.getLoc(), rewriter.getIndexAttr(unrollJamFactor)));
  forOp.setStep(newStep);
  auto forOpIV = forOp.getInductionVar();

  // Unroll and jam (appends unrollJamFactor - 1 additional copies).
  for (unsigned i = unrollJamFactor - 1; i >= 1; --i) {
    for (auto &subBlock : subBlocks) {
      // Builder to insert unroll-jammed bodies. Insert right at the end of
      // sub-block.
      OpBuilder builder(subBlock.first->getBlock(), std::next(subBlock.second));

      // If the induction variable is used, create a remapping to the value for
      // this unrolled instance.
      if (!forOpIV.use_empty()) {
        // iv' = iv + i * step, i = 1 to unrollJamFactor-1.
        auto ivTag = builder.createOrFold<arith::ConstantOp>(
            forOp.getLoc(), builder.getIndexAttr(step * i));
        auto ivUnroll =
            builder.createOrFold<arith::AddIOp>(forOp.getLoc(), forOpIV, ivTag);
        operandMaps[i - 1].map(forOpIV, ivUnroll);
      }
      // Clone the sub-block being unroll-jammed.
      for (auto it = subBlock.first; it != std::next(subBlock.second); ++it)
        builder.clone(*it, operandMaps[i - 1]);
    }
    // Fix iterOperands and yield op operands of newly created loops.
    for (auto newForOp : newInnerLoops) {
      unsigned oldNumIterOperands =
          newForOp.getNumRegionIterArgs() / unrollJamFactor;
      unsigned numControlOperands = newForOp.getNumControlOperands();
      auto yieldOp = cast<scf::YieldOp>(newForOp.getBody()->getTerminator());
      unsigned oldNumYieldOperands = yieldOp.getNumOperands() / unrollJamFactor;
      assert(oldNumIterOperands == oldNumYieldOperands &&
             "oldNumIterOperands must be the same as oldNumYieldOperands");
      for (unsigned j = 0; j < oldNumIterOperands; ++j) {
        // The `i`th duplication of an old iterOperand or yield op operand
        // needs to be replaced with a mapped value from `operandMaps[i - 1]`
        // if such mapped value exists.
        newForOp.setOperand(numControlOperands + i * oldNumIterOperands + j,
                            operandMaps[i - 1].lookupOrDefault(
                                newForOp.getOperand(numControlOperands + j)));
        yieldOp.setOperand(
            i * oldNumYieldOperands + j,
            operandMaps[i - 1].lookupOrDefault(yieldOp.getOperand(j)));
      }
    }
  }

  // Promote the loop body up if this has turned into a single iteration loop.
  (void)forOp.promoteIfSingleIteration(rewriter);
  return success();
}
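
// A sketch of the unroll-and-jam effect (illustrative bounds; each cloned
// sub-block recomputes its own remapped induction value). Unroll-jamming the
// following nest by 2, where %cN is defined outside the outer loop:
//
// ```mlir
//   scf.for %i = %c0 to %c4 step %c1 {
//     "prologue.op"(%i) : (index) -> ()
//     scf.for %j = %c0 to %cN step %c1 {
//       "body.op"(%i, %j) : (index, index) -> ()
//     }
//   }
// ```
//
// doubles the outer step and jams the copies of each sub-block, keeping a
// single inner loop:
//
// ```mlir
//   scf.for %i = %c0 to %c4 step %c2 {
//     "prologue.op"(%i) : (index) -> ()
//     %i.1 = arith.addi %i, %c1 : index
//     "prologue.op"(%i.1) : (index) -> ()
//     scf.for %j = %c0 to %cN step %c1 {
//       "body.op"(%i, %j) : (index, index) -> ()
//       %i.2 = arith.addi %i, %c1 : index
//       "body.op"(%i.2, %j) : (index, index) -> ()
//     }
//   }
// ```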

static Range emitNormalizedLoopBoundsForIndexType(RewriterBase &rewriter,
                                                  Location loc, OpFoldResult lb,
                                                  OpFoldResult ub,
                                                  OpFoldResult step) {
  Range normalizedLoopBounds;
  normalizedLoopBounds.offset = rewriter.getIndexAttr(0);
  normalizedLoopBounds.stride = rewriter.getIndexAttr(1);
  AffineExpr s0, s1, s2;
  bindSymbols(rewriter.getContext(), s0, s1, s2);
  AffineExpr e = (s1 - s0).ceilDiv(s2);
  normalizedLoopBounds.size =
      affine::makeComposedFoldedAffineApply(rewriter, loc, e, {lb, ub, step});
  return normalizedLoopBounds;
}

Range mlir::emitNormalizedLoopBounds(RewriterBase &rewriter, Location loc,
                                     OpFoldResult lb, OpFoldResult ub,
                                     OpFoldResult step) {
  if (getType(lb).isIndex()) {
    return emitNormalizedLoopBoundsForIndexType(rewriter, loc, lb, ub, step);
  }
  // For non-index types, generate `arith` instructions.
  // Check if the loop is already known to have a constant zero lower bound or
  // a constant one step.
  bool isZeroBased = false;
  if (auto lbCst = getConstantIntValue(lb))
    isZeroBased = lbCst.value() == 0;

  bool isStepOne = false;
  if (auto stepCst = getConstantIntValue(step))
    isStepOne = stepCst.value() == 1;

  Type rangeType = getType(lb);
  assert(rangeType == getType(ub) && rangeType == getType(step) &&
         "expected matching types");

  // Compute the number of iterations the loop executes:
  // ceildiv(ub - lb, step), assuming the step is strictly positive. Update the
  // bounds and the step of the loop to go from 0 to the number of iterations,
  // if necessary.
  if (isZeroBased && isStepOne)
    return {lb, ub, step};

  OpFoldResult diff = ub;
  if (!isZeroBased) {
    diff = rewriter.createOrFold<arith::SubIOp>(
        loc, getValueOrCreateConstantIntOp(rewriter, loc, ub),
        getValueOrCreateConstantIntOp(rewriter, loc, lb));
  }
  OpFoldResult newUpperBound = diff;
  if (!isStepOne) {
    newUpperBound = rewriter.createOrFold<arith::CeilDivSIOp>(
        loc, getValueOrCreateConstantIntOp(rewriter, loc, diff),
        getValueOrCreateConstantIntOp(rewriter, loc, step));
  }

  OpFoldResult newLowerBound = rewriter.getZeroAttr(rangeType);
  OpFoldResult newStep = rewriter.getOneAttr(rangeType);

  return {newLowerBound, newUpperBound, newStep};
}
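
// Worked example: normalizing a loop running from 4 to 20 with step 3 yields
// the range [0, ceildiv(20 - 4, 3)) = [0, 6) with step 1. The original
// induction variable is then recovered as iv * 3 + 4, which is exactly what
// `denormalizeInductionVariable` below emits.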

static void denormalizeInductionVariableForIndexType(RewriterBase &rewriter,
                                                     Location loc,
                                                     Value normalizedIv,
                                                     OpFoldResult origLb,
                                                     OpFoldResult origStep) {
  AffineExpr d0, s0, s1;
  bindSymbols(rewriter.getContext(), s0, s1);
  bindDims(rewriter.getContext(), d0);
  AffineExpr e = d0 * s1 + s0;
  OpFoldResult denormalizedIv = affine::makeComposedFoldedAffineApply(
      rewriter, loc, e, ArrayRef<OpFoldResult>{normalizedIv, origLb, origStep});
  Value denormalizedIvVal =
      getValueOrCreateConstantIndexOp(rewriter, loc, denormalizedIv);
  SmallPtrSet<Operation *, 1> preservedUses;
  // If an `affine.apply` operation is generated for denormalization, the use
  // of `origLb` in that op must not be replaced. Such ops aren't generated
  // when `origLb == 0` and `origStep == 1`.
  if (!isConstantIntValue(origLb, 0) || !isConstantIntValue(origStep, 1)) {
    if (Operation *preservedUse = denormalizedIvVal.getDefiningOp()) {
      preservedUses.insert(preservedUse);
    }
  }
  rewriter.replaceAllUsesExcept(normalizedIv, denormalizedIvVal, preservedUses);
}

void mlir::denormalizeInductionVariable(RewriterBase &rewriter, Location loc,
                                        Value normalizedIv, OpFoldResult origLb,
                                        OpFoldResult origStep) {
  if (getType(origLb).isIndex()) {
    return denormalizeInductionVariableForIndexType(rewriter, loc,
                                                    normalizedIv, origLb,
                                                    origStep);
  }
  Value denormalizedIv;
  SmallPtrSet<Operation *, 2> preserve;
  bool isStepOne = isConstantIntValue(origStep, 1);
  bool isZeroBased = isConstantIntValue(origLb, 0);

  Value scaled = normalizedIv;
  if (!isStepOne) {
    Value origStepValue =
        getValueOrCreateConstantIntOp(rewriter, loc, origStep);
    scaled = rewriter.create<arith::MulIOp>(loc, normalizedIv, origStepValue);
    preserve.insert(scaled.getDefiningOp());
  }
  denormalizedIv = scaled;
  if (!isZeroBased) {
    Value origLbValue = getValueOrCreateConstantIntOp(rewriter, loc, origLb);
    denormalizedIv = rewriter.create<arith::AddIOp>(loc, scaled, origLbValue);
    preserve.insert(denormalizedIv.getDefiningOp());
  }

  rewriter.replaceAllUsesExcept(normalizedIv, denormalizedIv, preserve);
}
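
// For a non-index normalized IV %niv with original lower bound %lb and step
// %step (neither statically trivial), the emitted denormalization is,
// schematically:
//
// ```mlir
//   %scaled = arith.muli %niv, %step : i64
//   %denorm = arith.addi %scaled, %lb : i64
// ```
//
// and all uses of %niv other than the two ops above are redirected to %denorm.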

static OpFoldResult getProductOfIndexes(RewriterBase &rewriter, Location loc,
                                        ArrayRef<OpFoldResult> values) {
  assert(!values.empty() && "unexpected empty array");
  AffineExpr s0, s1;
  bindSymbols(rewriter.getContext(), s0, s1);
  AffineExpr mul = s0 * s1;
  OpFoldResult products = rewriter.getIndexAttr(1);
  for (auto v : values) {
    products = affine::makeComposedFoldedAffineApply(
        rewriter, loc, mul, ArrayRef<OpFoldResult>{products, v});
  }
  return products;
}

/// Helper function to multiply a sequence of values.
static Value getProductOfIntsOrIndexes(RewriterBase &rewriter, Location loc,
                                       ArrayRef<Value> values) {
  assert(!values.empty() && "unexpected empty list");
  if (getType(values.front()).isIndex()) {
    SmallVector<OpFoldResult> ofrs = getAsOpFoldResult(values);
    OpFoldResult product = getProductOfIndexes(rewriter, loc, ofrs);
    return getValueOrCreateConstantIndexOp(rewriter, loc, product);
  }
  std::optional<Value> productOf;
  for (auto v : values) {
    auto vOne = getConstantIntValue(v);
    if (vOne && vOne.value() == 1)
      continue;
    if (productOf)
      productOf =
          rewriter.create<arith::MulIOp>(loc, productOf.value(), v).getResult();
    else
      productOf = v;
  }
  if (!productOf) {
    productOf = rewriter
                    .create<arith::ConstantOp>(
                        loc, rewriter.getOneAttr(getType(values.front())))
                    .getResult();
  }
  return productOf.value();
}

/// For each original loop, the value of its induction variable can be
/// obtained by dividing the induction variable of the linearized loop by the
/// total number of iterations of the loops nested inside it, modulo the
/// number of iterations of this loop (the modulo removes the contribution of
/// the outer loops):
///   iv_i = floordiv(iv_linear, product-of-loop-ranges-until-i) mod range_i.
/// Compute these iteratively from the innermost loop by creating a "running
/// quotient" of division by the range.
static std::pair<SmallVector<Value>, SmallPtrSet<Operation *, 2>>
delinearizeInductionVariable(RewriterBase &rewriter, Location loc,
                             Value linearizedIv, ArrayRef<Value> ubs) {

  if (linearizedIv.getType().isIndex()) {
    Operation *delinearizedOp =
        rewriter.create<affine::AffineDelinearizeIndexOp>(loc, linearizedIv,
                                                          ubs);
    auto resultVals = llvm::map_to_vector(
        delinearizedOp->getResults(), [](OpResult r) -> Value { return r; });
    return {resultVals, SmallPtrSet<Operation *, 2>{delinearizedOp}};
  }

  SmallVector<Value> delinearizedIvs(ubs.size());
  SmallPtrSet<Operation *, 2> preservedUsers;

  llvm::BitVector isUbOne(ubs.size());
  for (auto [index, ub] : llvm::enumerate(ubs)) {
    auto ubCst = getConstantIntValue(ub);
    if (ubCst && ubCst.value() == 1)
      isUbOne.set(index);
  }

  // Prune the leading ubs that are all ones.
  unsigned numLeadingOneUbs = 0;
  for (auto [index, ub] : llvm::enumerate(ubs)) {
    if (!isUbOne.test(index)) {
      break;
    }
    delinearizedIvs[index] = rewriter.create<arith::ConstantOp>(
        loc, rewriter.getZeroAttr(ub.getType()));
    numLeadingOneUbs++;
  }

  Value previous = linearizedIv;
  for (unsigned i = numLeadingOneUbs, e = ubs.size(); i < e; ++i) {
    unsigned idx = ubs.size() - (i - numLeadingOneUbs) - 1;
    if (i != numLeadingOneUbs && !isUbOne.test(idx + 1)) {
      previous = rewriter.create<arith::DivSIOp>(loc, previous, ubs[idx + 1]);
      preservedUsers.insert(previous.getDefiningOp());
    }
    Value iv = previous;
    if (i != e - 1) {
      if (!isUbOne.test(idx)) {
        iv = rewriter.create<arith::RemSIOp>(loc, previous, ubs[idx]);
        preservedUsers.insert(iv.getDefiningOp());
      } else {
        iv = rewriter.create<arith::ConstantOp>(
            loc, rewriter.getZeroAttr(ubs[idx].getType()));
      }
    }
    delinearizedIvs[idx] = iv;
  }
  return {delinearizedIvs, preservedUsers};
}
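
// For example (a sketch for the non-index path), delinearizing %linear_iv
// against upper bounds (%c4, %c8) emits:
//
// ```mlir
//   %j = arith.remsi %linear_iv, %c8 : i64  // innermost iv
//   %i = arith.divsi %linear_iv, %c8 : i64  // outermost iv
// ```
//
// so that %linear_iv = %i * 8 + %j. For index types a single
// affine.delinearize_index op is emitted instead.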

LogicalResult mlir::coalesceLoops(RewriterBase &rewriter,
                                  MutableArrayRef<scf::ForOp> loops) {
  if (loops.size() < 2)
    return failure();

  scf::ForOp innermost = loops.back();
  scf::ForOp outermost = loops.front();

  // 1. Make sure all loops iterate from 0 to upperBound with step 1. This
  // allows the following code to assume upperBound is the number of
  // iterations.
  for (auto loop : loops) {
    OpBuilder::InsertionGuard g(rewriter);
    rewriter.setInsertionPoint(outermost);
    Value lb = loop.getLowerBound();
    Value ub = loop.getUpperBound();
    Value step = loop.getStep();
    auto newLoopRange =
        emitNormalizedLoopBounds(rewriter, loop.getLoc(), lb, ub, step);

    rewriter.modifyOpInPlace(loop, [&]() {
      loop.setLowerBound(getValueOrCreateConstantIntOp(rewriter, loop.getLoc(),
                                                       newLoopRange.offset));
      loop.setUpperBound(getValueOrCreateConstantIntOp(rewriter, loop.getLoc(),
                                                       newLoopRange.size));
      loop.setStep(getValueOrCreateConstantIntOp(rewriter, loop.getLoc(),
                                                 newLoopRange.stride));
    });
    rewriter.setInsertionPointToStart(innermost.getBody());
    denormalizeInductionVariable(rewriter, loop.getLoc(),
                                 loop.getInductionVar(), lb, step);
  }

  // 2. Emit code computing the upper bound of the coalesced loop as product
  // of the number of iterations of all loops.
  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPoint(outermost);
  Location loc = outermost.getLoc();
  SmallVector<Value> upperBounds = llvm::map_to_vector(
      loops, [](auto loop) { return loop.getUpperBound(); });
  Value upperBound = getProductOfIntsOrIndexes(rewriter, loc, upperBounds);
  outermost.setUpperBound(upperBound);

  rewriter.setInsertionPointToStart(innermost.getBody());
  auto [delinearizeIvs, preservedUsers] = delinearizeInductionVariable(
      rewriter, loc, outermost.getInductionVar(), upperBounds);
  rewriter.replaceAllUsesExcept(outermost.getInductionVar(), delinearizeIvs[0],
                                preservedUsers);

  for (int i = loops.size() - 1; i > 0; --i) {
    auto outerLoop = loops[i - 1];
    auto innerLoop = loops[i];

    Operation *innerTerminator = innerLoop.getBody()->getTerminator();
    auto yieldedVals = llvm::to_vector(innerTerminator->getOperands());
    assert(llvm::equal(outerLoop.getRegionIterArgs(), innerLoop.getInitArgs()));
    for (Value &yieldedVal : yieldedVals) {
      // The yielded value may be an iteration argument of the inner loop
      // which is about to be inlined.
      auto iter = llvm::find(innerLoop.getRegionIterArgs(), yieldedVal);
      if (iter != innerLoop.getRegionIterArgs().end()) {
        unsigned iterArgIndex = iter - innerLoop.getRegionIterArgs().begin();
        // The `outerLoop` iter args are identical to the `innerLoop` init
        // args.
        assert(iterArgIndex < innerLoop.getInitArgs().size());
        yieldedVal = innerLoop.getInitArgs()[iterArgIndex];
      }
    }
    rewriter.eraseOp(innerTerminator);

    SmallVector<Value> innerBlockArgs;
    innerBlockArgs.push_back(delinearizeIvs[i]);
    llvm::append_range(innerBlockArgs, outerLoop.getRegionIterArgs());
    rewriter.inlineBlockBefore(innerLoop.getBody(), outerLoop.getBody(),
                               Block::iterator(innerLoop), innerBlockArgs);
    rewriter.replaceOp(innerLoop, yieldedVals);
  }
  return success();
}
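
// A sketch of the overall effect of `coalesceLoops` on an already-normalized
// 2-d nest (illustrative constants):
//
// ```mlir
//   scf.for %i = %c0 to %c4 step %c1 {
//     scf.for %j = %c0 to %c8 step %c1 {
//       "some.op"(%i, %j) : (index, index) -> ()
//     }
//   }
// ```
//
// becomes a single loop over the product of the trip counts whose induction
// variable is delinearized back into (%i, %j):
//
// ```mlir
//   scf.for %iv = %c0 to %c32 step %c1 {
//     %ij:2 = affine.delinearize_index %iv into (%c4, %c8) : index, index
//     "some.op"(%ij#0, %ij#1) : (index, index) -> ()
//   }
// ```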

LogicalResult mlir::coalesceLoops(MutableArrayRef<scf::ForOp> loops) {
  if (loops.empty()) {
    return failure();
  }
  IRRewriter rewriter(loops.front().getContext());
  return coalesceLoops(rewriter, loops);
}

LogicalResult mlir::coalescePerfectlyNestedSCFForLoops(scf::ForOp op) {
  LogicalResult result(failure());
  SmallVector<scf::ForOp> loops;
  getPerfectlyNestedLoops(loops, op);

  // Look for a band of loops that can be coalesced, i.e. perfectly nested
  // loops with bounds defined above some loop.

  // 1. For each loop, find above which parent loop its bounds operands are
  // defined.
  SmallVector<unsigned> operandsDefinedAbove(loops.size());
  for (unsigned i = 0, e = loops.size(); i < e; ++i) {
    operandsDefinedAbove[i] = i;
    for (unsigned j = 0; j < i; ++j) {
      SmallVector<Value> boundsOperands = {loops[i].getLowerBound(),
                                           loops[i].getUpperBound(),
                                           loops[i].getStep()};
      if (areValuesDefinedAbove(boundsOperands, loops[j].getRegion())) {
        operandsDefinedAbove[i] = j;
        break;
      }
    }
  }

  // 2. For each inner loop, check that the iter_args of the immediately outer
  // loop are the inits of the immediately inner loop, and that the results of
  // the immediately inner loop are the operands yielded by the immediately
  // outer loop. Keep track of where the chain starts for each loop.
  SmallVector<unsigned> iterArgChainStart(loops.size());
  iterArgChainStart[0] = 0;
  for (unsigned i = 1, e = loops.size(); i < e; ++i) {
    // By default set the start of the chain to itself.
    iterArgChainStart[i] = i;
    auto outerloop = loops[i - 1];
    auto innerLoop = loops[i];
    if (outerloop.getNumRegionIterArgs() != innerLoop.getNumRegionIterArgs()) {
      continue;
    }
    if (!llvm::equal(outerloop.getRegionIterArgs(), innerLoop.getInitArgs())) {
      continue;
    }
    auto outerloopTerminator = outerloop.getBody()->getTerminator();
    if (!llvm::equal(outerloopTerminator->getOperands(),
                     innerLoop.getResults())) {
      continue;
    }
    iterArgChainStart[i] = iterArgChainStart[i - 1];
  }

  // 3. Identify bands of loops such that the operands of all of them are
  // defined above the first loop in the band. Traverse the nest bottom-up
  // so that modifications don't invalidate the inner loops.
  for (unsigned end = loops.size(); end > 0; --end) {
    unsigned start = 0;
    for (; start < end - 1; ++start) {
      auto maxPos =
          *std::max_element(std::next(operandsDefinedAbove.begin(), start),
                            std::next(operandsDefinedAbove.begin(), end));
      if (maxPos > start)
        continue;
      if (iterArgChainStart[end - 1] > start)
        continue;
      auto band = llvm::MutableArrayRef(loops.data() + start, end - start);
      if (succeeded(coalesceLoops(band)))
        result = success();
      break;
    }
    // If a band was found and transformed, keep looking at the loops above
    // the outermost transformed loop.
    if (start != end - 1)
      end = start + 1;
  }
  return result;
}

void mlir::collapseParallelLoops(
    RewriterBase &rewriter, scf::ParallelOp loops,
    ArrayRef<std::vector<unsigned>> combinedDimensions) {
  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPoint(loops);
  Location loc = loops.getLoc();

  // Presort combined dimensions.
  auto sortedDimensions = llvm::to_vector<3>(combinedDimensions);
  for (auto &dims : sortedDimensions)
    llvm::sort(dims);

  // Normalize ParallelOp's iteration pattern.
  SmallVector<Value, 3> normalizedUpperBounds;
  for (unsigned i = 0, e = loops.getNumLoops(); i < e; ++i) {
    OpBuilder::InsertionGuard g2(rewriter);
    rewriter.setInsertionPoint(loops);
    Value lb = loops.getLowerBound()[i];
    Value ub = loops.getUpperBound()[i];
    Value step = loops.getStep()[i];
    auto newLoopRange = emitNormalizedLoopBounds(rewriter, loc, lb, ub, step);
    normalizedUpperBounds.push_back(getValueOrCreateConstantIntOp(
        rewriter, loops.getLoc(), newLoopRange.size));

    rewriter.setInsertionPointToStart(loops.getBody());
    denormalizeInductionVariable(rewriter, loc, loops.getInductionVars()[i], lb,
                                 step);
  }

  // Combine iteration spaces.
  SmallVector<Value, 3> lowerBounds, upperBounds, steps;
  auto cst0 = rewriter.create<arith::ConstantIndexOp>(loc, 0);
  auto cst1 = rewriter.create<arith::ConstantIndexOp>(loc, 1);
  for (auto &sortedDimension : sortedDimensions) {
    Value newUpperBound = rewriter.create<arith::ConstantIndexOp>(loc, 1);
    for (auto idx : sortedDimension) {
      newUpperBound = rewriter.create<arith::MulIOp>(
          loc, newUpperBound, normalizedUpperBounds[idx]);
    }
    lowerBounds.push_back(cst0);
    steps.push_back(cst1);
    upperBounds.push_back(newUpperBound);
  }

  // Create new ParallelLoop with conversions to the original induction values.
  // The loop below uses divisions to extract, from each new induction value,
  // the range of values that corresponds to each original induction value.
  // The remainders then determine, within that range, which iteration of the
  // original induction variable this represents. The recovered value is a
  // normalized one; the normalization was already undone by the
  // denormalization logic above.
  auto newPloop = rewriter.create<scf::ParallelOp>(
      loc, lowerBounds, upperBounds, steps,
      [&](OpBuilder &insideBuilder, Location, ValueRange ploopIVs) {
        for (unsigned i = 0, e = combinedDimensions.size(); i < e; ++i) {
          Value previous = ploopIVs[i];
          unsigned numberCombinedDimensions = combinedDimensions[i].size();
          // Iterate over all except the last induction value.
          for (unsigned j = numberCombinedDimensions - 1; j > 0; --j) {
            unsigned idx = combinedDimensions[i][j];

            // Determine the current induction variable's loop iteration.
            Value iv = insideBuilder.create<arith::RemSIOp>(
                loc, previous, normalizedUpperBounds[idx]);
            replaceAllUsesInRegionWith(loops.getBody()->getArgument(idx), iv,
                                       loops.getRegion());

            // Remove the effect of the current induction value to prepare for
            // the next value.
            previous = insideBuilder.create<arith::DivSIOp>(
                loc, previous, normalizedUpperBounds[idx]);
          }

          // The final induction value is just the remaining value.
          unsigned idx = combinedDimensions[i][0];
          replaceAllUsesInRegionWith(loops.getBody()->getArgument(idx),
                                     previous, loops.getRegion());
        }
      });

  // Replace the old loop with the new loop.
  loops.getBody()->back().erase();
  newPloop.getBody()->getOperations().splice(
      Block::iterator(newPloop.getBody()->back()),
      loops.getBody()->getOperations());
  loops.erase();
}
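
// For example (a sketch, with made-up bounds), collapsing all three
// dimensions of
//
// ```mlir
//   scf.parallel (%i, %j, %k) = (%c0, %c0, %c0) to (%c2, %c3, %c4)
//       step (%c1, %c1, %c1) { ... }
// ```
//
// with combinedDimensions = {{0, 1, 2}} produces a 1-d scf.parallel over
// [0, 24) whose body reconstructs %i, %j and %k from the single induction
// variable using arith.remsi/arith.divsi against the normalized upper bounds.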

// Hoist the ops within `outer` that appear before `inner`.
// Such ops include the ops that have been introduced by parametric tiling.
// Ops that come from triangular loops (i.e. that belong to the program slice
// rooted at `outer`) and ops that have side effects cannot be hoisted.
// Return failure when any op fails to hoist.
static LogicalResult hoistOpsBetween(scf::ForOp outer, scf::ForOp inner) {
  SetVector<Operation *> forwardSlice;
  ForwardSliceOptions options;
  options.filter = [&inner](Operation *op) {
    return op != inner.getOperation();
  };
  getForwardSlice(outer.getInductionVar(), &forwardSlice, options);
  LogicalResult status = success();
  SmallVector<Operation *, 8> toHoist;
  for (auto &op : outer.getBody()->without_terminator()) {
    // Stop when encountering the inner loop.
    if (&op == inner.getOperation())
      break;
    // Skip over non-hoistable ops.
    if (forwardSlice.count(&op) > 0) {
      status = failure();
      continue;
    }
    // Skip intermediate scf::ForOp, these are not considered a failure.
    if (isa<scf::ForOp>(op))
      continue;
    // Skip other ops with regions.
    if (op.getNumRegions() > 0) {
      status = failure();
      continue;
    }
    // Skip if op has side effects.
    // TODO: loads to immutable memory regions are ok.
    if (!isMemoryEffectFree(&op)) {
      status = failure();
      continue;
    }
    toHoist.push_back(&op);
  }
  auto *outerForOp = outer.getOperation();
  for (auto *op : toHoist)
    op->moveBefore(outerForOp);
  return status;
}

// Traverse the interTile and intraTile loops and try to hoist ops such that
// bands of perfectly nested loops are isolated.
// Return failure if either perfect interTile or perfect intraTile bands cannot
// be formed.
static LogicalResult tryIsolateBands(const TileLoops &tileLoops) {
  LogicalResult status = success();
  const Loops &interTile = tileLoops.first;
  const Loops &intraTile = tileLoops.second;
  auto size = interTile.size();
  assert(size == intraTile.size());
  if (size <= 1)
    return success();
  for (unsigned s = 1; s < size; ++s)
    status = succeeded(status) ? hoistOpsBetween(intraTile[0], intraTile[s])
                               : failure();
  for (unsigned s = 1; s < size; ++s)
    status = succeeded(status) ? hoistOpsBetween(interTile[0], interTile[s])
                               : failure();
  return status;
}

/// Collect perfectly nested loops starting from `rootForOp`. Loops are
/// perfectly nested if each loop is the first and only non-terminator
/// operation in the parent loop. Collect at most `maxLoops` loops and append
/// them to `forOps`.
template <typename T>
static void getPerfectlyNestedLoopsImpl(
    SmallVectorImpl<T> &forOps, T rootForOp,
    unsigned maxLoops = std::numeric_limits<unsigned>::max()) {
  for (unsigned i = 0; i < maxLoops; ++i) {
    forOps.push_back(rootForOp);
    Block &body = rootForOp.getRegion().front();
    if (body.begin() != std::prev(body.end(), 2))
      return;

    rootForOp = dyn_cast<T>(&body.front());
    if (!rootForOp)
      return;
  }
}

static Loops stripmineSink(scf::ForOp forOp, Value factor,
                           ArrayRef<scf::ForOp> targets) {
  auto originalStep = forOp.getStep();
  auto iv = forOp.getInductionVar();

  OpBuilder b(forOp);
  forOp.setStep(b.create<arith::MulIOp>(forOp.getLoc(), originalStep, factor));

  Loops innerLoops;
  for (auto t : targets) {
    // Save information for splicing ops out of t when done.
    auto begin = t.getBody()->begin();
    auto nOps = t.getBody()->getOperations().size();

    // Insert newForOp before the terminator of `t`.
    auto b = OpBuilder::atBlockTerminator(t.getBody());
    Value stepped = b.create<arith::AddIOp>(t.getLoc(), iv, forOp.getStep());
    Value ub =
        b.create<arith::MinSIOp>(t.getLoc(), forOp.getUpperBound(), stepped);

    // Splice [begin, begin + nOps - 1) into `newForOp` and replace uses.
    auto newForOp = b.create<scf::ForOp>(t.getLoc(), iv, ub, originalStep);
    newForOp.getBody()->getOperations().splice(
        newForOp.getBody()->getOperations().begin(),
        t.getBody()->getOperations(), begin, std::next(begin, nOps - 1));
    replaceAllUsesInRegionWith(iv, newForOp.getInductionVar(),
                               newForOp.getRegion());

    innerLoops.push_back(newForOp);
  }

  return innerLoops;
}

// Stripmines a `forOp` by `factor` and sinks it under a single `target`.
// Returns the new for operation, nested immediately under `target`.
template <typename SizeType>
static scf::ForOp stripmineSink(scf::ForOp forOp, SizeType factor,
                                scf::ForOp target) {
  // TODO: Use cheap structural assertions that targets are nested under
  // forOp and that targets are not nested under each other when DominanceInfo
  // exposes the capability. It seems overkill to construct a whole function
  // dominance tree at this point.
  auto res = stripmineSink(forOp, factor, ArrayRef<scf::ForOp>(target));
  assert(res.size() == 1 && "Expected 1 inner forOp");
  return res[0];
}

SmallVector<Loops, 8> mlir::tile(ArrayRef<scf::ForOp> forOps,
                                 ArrayRef<Value> sizes,
                                 ArrayRef<scf::ForOp> targets) {
  SmallVector<SmallVector<scf::ForOp, 8>, 8> res;
  SmallVector<scf::ForOp, 8> currentTargets(targets);
  for (auto it : llvm::zip(forOps, sizes)) {
    auto step = stripmineSink(std::get<0>(it), std::get<1>(it), currentTargets);
    res.push_back(step);
    currentTargets = step;
  }
  return res;
}

Loops mlir::tile(ArrayRef<scf::ForOp> forOps, ArrayRef<Value> sizes,
                 scf::ForOp target) {
  SmallVector<scf::ForOp, 8> res;
  for (auto loops : tile(forOps, sizes, ArrayRef<scf::ForOp>(target))) {
    assert(loops.size() == 1);
    res.push_back(loops[0]);
  }
  return res;
}

Loops mlir::tilePerfectlyNested(scf::ForOp rootForOp, ArrayRef<Value> sizes) {
  // Collect perfectly nested loops. If more size values provided than nested
  // loops available, truncate `sizes`.
  SmallVector<scf::ForOp, 4> forOps;
  forOps.reserve(sizes.size());
  getPerfectlyNestedLoopsImpl(forOps, rootForOp, sizes.size());
  if (forOps.size() < sizes.size())
    sizes = sizes.take_front(forOps.size());

  return ::tile(forOps, sizes, forOps.back());
}
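
// As a sketch of the strip-mine-and-sink step underlying `tile` (bounds are
// illustrative), tiling the loop
//
// ```mlir
//   scf.for %i = %lb to %ub step %c1 {
//     "some.op"(%i) : (index) -> ()
//   }
// ```
//
// with size %c4 produces an outer loop stepping by 4 and an inner loop
// covering one tile, clamped at the original upper bound:
//
// ```mlir
//   scf.for %i = %lb to %ub step %c4 {
//     %stepped = arith.addi %i, %c4 : index
//     %tile_ub = arith.minsi %ub, %stepped : index
//     scf.for %ii = %i to %tile_ub step %c1 {
//       "some.op"(%ii) : (index) -> ()
//     }
//   }
// ```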

void mlir::getPerfectlyNestedLoops(SmallVectorImpl<scf::ForOp> &nestedLoops,
                                   scf::ForOp root) {
  getPerfectlyNestedLoopsImpl(nestedLoops, root);
}

TileLoops mlir::extractFixedOuterLoops(scf::ForOp rootForOp,
                                       ArrayRef<int64_t> sizes) {
  // Collect perfectly nested loops. If more size values provided than nested
  // loops available, truncate `sizes`.
  SmallVector<scf::ForOp, 4> forOps;
  forOps.reserve(sizes.size());
  getPerfectlyNestedLoopsImpl(forOps, rootForOp, sizes.size());
  if (forOps.size() < sizes.size())
    sizes = sizes.take_front(forOps.size());

  // Compute the tile sizes such that the i-th outer loop executes size[i]
  // iterations. Given that the loop currently executes
  //   numIterations = ceildiv((upperBound - lowerBound), step)
  // iterations, we need to tile with size ceildiv(numIterations, size[i]).
  SmallVector<Value, 4> tileSizes;
  tileSizes.reserve(sizes.size());
  for (unsigned i = 0, e = sizes.size(); i < e; ++i) {
    assert(sizes[i] > 0 && "expected strictly positive size for strip-mining");

    auto forOp = forOps[i];
    OpBuilder builder(forOp);
    auto loc = forOp.getLoc();
    Value diff = builder.create<arith::SubIOp>(loc, forOp.getUpperBound(),
                                               forOp.getLowerBound());
    Value numIterations = ceilDivPositive(builder, loc, diff, forOp.getStep());
    Value iterationsPerBlock =
        ceilDivPositive(builder, loc, numIterations, sizes[i]);
    tileSizes.push_back(iterationsPerBlock);
  }

  // Call parametric tiling with the given sizes.
  auto intraTile = tile(forOps, tileSizes, forOps.back());
  TileLoops tileLoops = std::make_pair(forOps, intraTile);

  // TODO: for now we just ignore the result of band isolation.
  // In the future, mapping decisions may be impacted by the ability to
  // isolate perfectly nested bands.
  (void)tryIsolateBands(tileLoops);

  return tileLoops;
}

scf::ForallOp mlir::fuseIndependentSiblingForallLoops(scf::ForallOp target,
                                                      scf::ForallOp source,
                                                      RewriterBase &rewriter) {
  unsigned numTargetOuts = target.getNumResults();
  unsigned numSourceOuts = source.getNumResults();

  // Create fused shared_outs.
  SmallVector<Value> fusedOuts;
  llvm::append_range(fusedOuts, target.getOutputs());
  llvm::append_range(fusedOuts, source.getOutputs());

  // Create a new scf.forall op after the source loop.
  rewriter.setInsertionPointAfter(source);
  scf::ForallOp fusedLoop = rewriter.create<scf::ForallOp>(
      source.getLoc(), source.getMixedLowerBound(), source.getMixedUpperBound(),
      source.getMixedStep(), fusedOuts, source.getMapping());

  // Map control operands.
  IRMapping mapping;
  mapping.map(target.getInductionVars(), fusedLoop.getInductionVars());
  mapping.map(source.getInductionVars(), fusedLoop.getInductionVars());

  // Map shared outs.
  mapping.map(target.getRegionIterArgs(),
              fusedLoop.getRegionIterArgs().take_front(numTargetOuts));
  mapping.map(source.getRegionIterArgs(),
              fusedLoop.getRegionIterArgs().take_back(numSourceOuts));

  // Append everything except the terminator into the fused operation.
  rewriter.setInsertionPointToStart(fusedLoop.getBody());
  for (Operation &op : target.getBody()->without_terminator())
    rewriter.clone(op, mapping);
  for (Operation &op : source.getBody()->without_terminator())
    rewriter.clone(op, mapping);

  // Fuse the old terminator in_parallel ops into the new one.
  scf::InParallelOp targetTerm = target.getTerminator();
  scf::InParallelOp sourceTerm = source.getTerminator();
  scf::InParallelOp fusedTerm = fusedLoop.getTerminator();
  rewriter.setInsertionPointToStart(fusedTerm.getBody());
  for (Operation &op : targetTerm.getYieldingOps())
    rewriter.clone(op, mapping);
  for (Operation &op : sourceTerm.getYieldingOps())
    rewriter.clone(op, mapping);

  // Replace old loops by substituting their uses by results of the fused loop.
  rewriter.replaceOp(target, fusedLoop.getResults().take_front(numTargetOuts));
  rewriter.replaceOp(source, fusedLoop.getResults().take_back(numSourceOuts));

  return fusedLoop;
}
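
// For example (a sketch, assuming both loops are independent and have
// identical bounds and mapping; tensors and names are made up), fusing
//
// ```mlir
//   %t = scf.forall (%i) in (8) shared_outs(%o0 = %a) -> (tensor<8xf32>) {...}
//   %s = scf.forall (%i) in (8) shared_outs(%o1 = %b) -> (tensor<8xf32>) {...}
// ```
//
// produces a single scf.forall with both shared_outs and both bodies, whose
// first result replaces %t and whose second result replaces %s.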

scf::ForOp mlir::fuseIndependentSiblingForLoops(scf::ForOp target,
                                                scf::ForOp source,
                                                RewriterBase &rewriter) {
  unsigned numTargetOuts = target.getNumResults();
  unsigned numSourceOuts = source.getNumResults();

  // Create fused init_args, with target's init_args before source's init_args.
  SmallVector<Value> fusedInitArgs;
  llvm::append_range(fusedInitArgs, target.getInitArgs());
  llvm::append_range(fusedInitArgs, source.getInitArgs());

  // Create a new scf.for op after the source loop. The builder creates an
  // argument-less scf.yield terminator only when the fused init_args are
  // empty.
  rewriter.setInsertionPointAfter(source);
  scf::ForOp fusedLoop = rewriter.create<scf::ForOp>(
      source.getLoc(), source.getLowerBound(), source.getUpperBound(),
      source.getStep(), fusedInitArgs);

  // Map original induction variables and operands to those of the fused loop.
  IRMapping mapping;
  mapping.map(target.getInductionVar(), fusedLoop.getInductionVar());
  mapping.map(target.getRegionIterArgs(),
              fusedLoop.getRegionIterArgs().take_front(numTargetOuts));
  mapping.map(source.getInductionVar(), fusedLoop.getInductionVar());
  mapping.map(source.getRegionIterArgs(),
              fusedLoop.getRegionIterArgs().take_back(numSourceOuts));

  // Merge target's body into the new (fused) for loop and then source's body.
  rewriter.setInsertionPointToStart(fusedLoop.getBody());
  for (Operation &op : target.getBody()->without_terminator())
    rewriter.clone(op, mapping);
  for (Operation &op : source.getBody()->without_terminator())
    rewriter.clone(op, mapping);

  // Build fused yield results by appropriately mapping original yield
  // operands.
  SmallVector<Value> yieldResults;
  for (Value operand : target.getBody()->getTerminator()->getOperands())
    yieldResults.push_back(mapping.lookupOrDefault(operand));
  for (Value operand : source.getBody()->getTerminator()->getOperands())
    yieldResults.push_back(mapping.lookupOrDefault(operand));
  if (!yieldResults.empty())
    rewriter.create<scf::YieldOp>(source.getLoc(), yieldResults);

  // Replace old loops by substituting their uses by results of the fused loop.
  rewriter.replaceOp(target, fusedLoop.getResults().take_front(numTargetOuts));
  rewriter.replaceOp(source, fusedLoop.getResults().take_back(numSourceOuts));

  return fusedLoop;
}
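
// A sketch of the effect (assuming identical bounds and no dependence between
// the two loops; names are made up):
//
// ```mlir
//   %t = scf.for %i = %c0 to %c8 step %c1 iter_args(%a0 = %a) -> (f32) {...}
//   %s = scf.for %i = %c0 to %c8 step %c1 iter_args(%b0 = %b) -> (f32) {...}
// ```
//
// becomes a single scf.for with iter_args(%a0 = %a, %b0 = %b) -> (f32, f32)
// whose first result replaces %t and whose second result replaces %s.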

FailureOr<scf::ForallOp> mlir::normalizeForallOp(RewriterBase &rewriter,
                                                 scf::ForallOp forallOp) {
  SmallVector<OpFoldResult> lbs = forallOp.getMixedLowerBound();
  SmallVector<OpFoldResult> ubs = forallOp.getMixedUpperBound();
  SmallVector<OpFoldResult> steps = forallOp.getMixedStep();

  if (llvm::all_of(
          lbs, [](OpFoldResult ofr) { return isConstantIntValue(ofr, 0); }) &&
      llvm::all_of(
          steps, [](OpFoldResult ofr) { return isConstantIntValue(ofr, 1); })) {
    return forallOp;
  }

  SmallVector<OpFoldResult> newLbs, newUbs, newSteps;
  for (auto [lb, ub, step] : llvm::zip_equal(lbs, ubs, steps)) {
    Range normalizedLoopParams =
        emitNormalizedLoopBounds(rewriter, forallOp.getLoc(), lb, ub, step);
    newLbs.push_back(normalizedLoopParams.offset);
    newUbs.push_back(normalizedLoopParams.size);
    newSteps.push_back(normalizedLoopParams.stride);
  }

  auto normalizedForallOp = rewriter.create<scf::ForallOp>(
      forallOp.getLoc(), newLbs, newUbs, newSteps, forallOp.getOutputs(),
      forallOp.getMapping(), [](OpBuilder &, Location, ValueRange) {});

  rewriter.inlineRegionBefore(forallOp.getBodyRegion(),
                              normalizedForallOp.getBodyRegion(),
                              normalizedForallOp.getBodyRegion().begin());

  rewriter.replaceAllOpUsesWith(forallOp, normalizedForallOp);
  return normalizedForallOp;
}
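
// For example (a sketch, with made-up constant bounds), normalizing
//
// ```mlir
//   scf.forall (%i) = (4) to (20) step (3) { ... }
// ```
//
// produces a loop over [0, ceildiv(20 - 4, 3)) = [0, 6) with unit step:
//
// ```mlir
//   scf.forall (%i) = (0) to (6) step (1) { ... }
// ```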