//===- LoopSpecialization.cpp - scf.parallel/scf.for specialization ------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Specializes parallel loops and for loops for easier unrolling and
// vectorization.
//
//===----------------------------------------------------------------------===//

#include "PassDetail.h"
#include "mlir/Dialect/Affine/Analysis/AffineStructures.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/SCF/Transforms/Passes.h"
#include "mlir/Dialect/SCF/Transforms/Transforms.h"
#include "mlir/Dialect/SCF/Utils/AffineCanonicalizationUtils.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/DenseMap.h"

using namespace mlir;
using scf::ForOp;
using scf::ParallelOp;

/// Rewrite a parallel loop with bounds defined by an affine.min with a constant
/// into 2 loops after checking if the bounds are equal to that constant. This
/// is beneficial if the loop will almost always have the constant bound and
/// that version can be fully unrolled and vectorized.
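///
/// Illustrative sketch (hypothetical SSA names; the exact IR depends on the
/// input): a loop such as
///
///   %ub = affine.min affine_map<(d0) -> (d0, 4)>(%n)
///   scf.parallel (%i) = (%c0) to (%ub) step (%c1) { ... }
///
/// is rewritten along the lines of
///
///   %c4 = arith.constant 4 : index
///   %eq = arith.cmpi eq, %ub, %c4 : index
///   scf.if %eq {
///     scf.parallel (%i) = (%c0) to (%c4) step (%c1) { ... } // constant bound
///   } else {
///     scf.parallel (%i) = (%c0) to (%ub) step (%c1) { ... } // general case
///   }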
static void specializeParallelLoopForUnrolling(ParallelOp op) {
  SmallVector<int64_t, 2> constantIndices;
  constantIndices.reserve(op.getUpperBound().size());
  for (auto bound : op.getUpperBound()) {
    auto minOp = bound.getDefiningOp<AffineMinOp>();
    if (!minOp)
      return;
    int64_t minConstant = std::numeric_limits<int64_t>::max();
    for (AffineExpr expr : minOp.getMap().getResults()) {
      if (auto constantIndex = expr.dyn_cast<AffineConstantExpr>())
        minConstant = std::min(minConstant, constantIndex.getValue());
    }
    if (minConstant == std::numeric_limits<int64_t>::max())
      return;
    constantIndices.push_back(minConstant);
  }

  OpBuilder b(op);
  BlockAndValueMapping map;
  Value cond;
  for (auto bound : llvm::zip(op.getUpperBound(), constantIndices)) {
    Value constant =
        b.create<arith::ConstantIndexOp>(op.getLoc(), std::get<1>(bound));
    Value cmp = b.create<arith::CmpIOp>(op.getLoc(), arith::CmpIPredicate::eq,
                                        std::get<0>(bound), constant);
    cond = cond ? b.create<arith::AndIOp>(op.getLoc(), cond, cmp) : cmp;
    map.map(std::get<0>(bound), constant);
  }
  auto ifOp = b.create<scf::IfOp>(op.getLoc(), cond, /*withElseRegion=*/true);
  ifOp.getThenBodyBuilder().clone(*op.getOperation(), map);
  ifOp.getElseBodyBuilder().clone(*op.getOperation());
  op.erase();
}

/// Rewrite a for loop with bounds defined by an affine.min with a constant into
/// 2 loops after checking if the bounds are equal to that constant. This is
/// beneficial if the loop will almost always have the constant bound and that
/// version can be fully unrolled and vectorized.
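///
/// Illustrative sketch (hypothetical SSA names; the exact IR depends on the
/// input): a loop such as
///
///   %ub = affine.min affine_map<(d0) -> (d0, 64)>(%n)
///   scf.for %i = %c0 to %ub step %c1 { ... }
///
/// is rewritten along the lines of
///
///   %c64 = arith.constant 64 : index
///   %eq  = arith.cmpi eq, %ub, %c64 : index
///   scf.if %eq {
///     scf.for %i = %c0 to %c64 step %c1 { ... } // constant trip count
///   } else {
///     scf.for %i = %c0 to %ub step %c1 { ... }  // general case
///   }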
static void specializeForLoopForUnrolling(ForOp op) {
  auto bound = op.getUpperBound();
  auto minOp = bound.getDefiningOp<AffineMinOp>();
  if (!minOp)
    return;
  int64_t minConstant = std::numeric_limits<int64_t>::max();
  for (AffineExpr expr : minOp.getMap().getResults()) {
    if (auto constantIndex = expr.dyn_cast<AffineConstantExpr>())
      minConstant = std::min(minConstant, constantIndex.getValue());
  }
  if (minConstant == std::numeric_limits<int64_t>::max())
    return;

  OpBuilder b(op);
  BlockAndValueMapping map;
  Value constant = b.create<arith::ConstantIndexOp>(op.getLoc(), minConstant);
  Value cond = b.create<arith::CmpIOp>(op.getLoc(), arith::CmpIPredicate::eq,
                                       bound, constant);
  map.map(bound, constant);
  auto ifOp = b.create<scf::IfOp>(op.getLoc(), cond, /*withElseRegion=*/true);
  ifOp.getThenBodyBuilder().clone(*op.getOperation(), map);
  ifOp.getElseBodyBuilder().clone(*op.getOperation());
  op.erase();
}

/// Rewrite a for loop with bounds/step that potentially do not divide evenly
/// into a for loop where the step divides the iteration space evenly, followed
/// by another scf.for for the last (partial) iteration (if any).
///
/// This function rewrites the given scf.for loop in-place and clones it into
/// a second loop that handles the remaining iterations. It replaces all uses
/// of the unpeeled loop with the results of the newly generated partial loop.
///
/// The newly generated loop for the partial iteration is returned via
/// `partialIteration`. The boundary at which the loop is split (the new upper
/// bound of the main loop) is returned via `splitBound`. The return value
/// indicates whether the loop was rewritten or not.
static LogicalResult peelForLoop(RewriterBase &b, ForOp forOp,
                                 ForOp &partialIteration, Value &splitBound) {
  RewriterBase::InsertionGuard guard(b);
  auto lbInt = getConstantIntValue(forOp.getLowerBound());
  auto ubInt = getConstantIntValue(forOp.getUpperBound());
  auto stepInt = getConstantIntValue(forOp.getStep());

  // No specialization necessary if step already divides upper bound evenly.
  if (lbInt && ubInt && stepInt && (*ubInt - *lbInt) % *stepInt == 0)
    return failure();
  // No specialization necessary if step size is 1.
  if (stepInt == static_cast<int64_t>(1))
    return failure();

  auto loc = forOp.getLoc();
  AffineExpr sym0, sym1, sym2;
  bindSymbols(b.getContext(), sym0, sym1, sym2);
  // New upper bound: %ub - (%ub - %lb) mod %step
  auto modMap = AffineMap::get(0, 3, {sym1 - ((sym1 - sym0) % sym2)});
  b.setInsertionPoint(forOp);
  splitBound = b.createOrFold<AffineApplyOp>(loc, modMap,
                                             ValueRange{forOp.getLowerBound(),
                                                        forOp.getUpperBound(),
                                                        forOp.getStep()});

  // Create ForOp for partial iteration.
  b.setInsertionPointAfter(forOp);
  partialIteration = cast<ForOp>(b.clone(*forOp.getOperation()));
  partialIteration.getLowerBoundMutable().assign(splitBound);
  forOp.replaceAllUsesWith(partialIteration->getResults());
  partialIteration.getInitArgsMutable().assign(forOp->getResults());

  // Set new upper loop bound.
  b.updateRootInPlace(
      forOp, [&]() { forOp.getUpperBoundMutable().assign(splitBound); });

  return success();
}

template <typename OpTy, bool IsMin>
static void rewriteAffineOpAfterPeeling(RewriterBase &rewriter, ForOp forOp,
                                        ForOp partialIteration,
                                        Value previousUb) {
  Value mainIv = forOp.getInductionVar();
  Value partialIv = partialIteration.getInductionVar();
  assert(forOp.getStep() == partialIteration.getStep() &&
         "expected same step in main and partial loop");
  Value step = forOp.getStep();

  forOp.walk([&](OpTy affineOp) {
    AffineMap map = affineOp.getAffineMap();
    (void)scf::rewritePeeledMinMaxOp(rewriter, affineOp, map,
                                     affineOp.operands(), IsMin, mainIv,
                                     previousUb, step,
                                     /*insideLoop=*/true);
  });
  partialIteration.walk([&](OpTy affineOp) {
    AffineMap map = affineOp.getAffineMap();
    (void)scf::rewritePeeledMinMaxOp(rewriter, affineOp, map,
                                     affineOp.operands(), IsMin, partialIv,
                                     previousUb, step, /*insideLoop=*/false);
  });
}

LogicalResult mlir::scf::peelAndCanonicalizeForLoop(RewriterBase &rewriter,
                                                    ForOp forOp,
                                                    ForOp &partialIteration) {
  Value previousUb = forOp.getUpperBound();
  Value splitBound;
  if (failed(peelForLoop(rewriter, forOp, partialIteration, splitBound)))
    return failure();

  // Rewrite affine.min and affine.max ops.
  rewriteAffineOpAfterPeeling<AffineMinOp, /*IsMin=*/true>(
      rewriter, forOp, partialIteration, previousUb);
  rewriteAffineOpAfterPeeling<AffineMaxOp, /*IsMin=*/false>(
      rewriter, forOp, partialIteration, previousUb);

  return success();
}

static constexpr char kPeeledLoopLabel[] = "__peeled_loop__";
static constexpr char kPartialIterationLabel[] = "__partial_iteration__";

namespace {
struct ForLoopPeelingPattern : public OpRewritePattern<ForOp> {
  ForLoopPeelingPattern(MLIRContext *ctx, bool skipPartial)
      : OpRewritePattern<ForOp>(ctx), skipPartial(skipPartial) {}

  LogicalResult matchAndRewrite(ForOp forOp,
                                PatternRewriter &rewriter) const override {
    // Do not peel already peeled loops.
    if (forOp->hasAttr(kPeeledLoopLabel))
      return failure();
    if (skipPartial) {
      // No peeling of loops inside the partial iteration of another peeled
      // loop.
      Operation *op = forOp.getOperation();
      while ((op = op->getParentOfType<scf::ForOp>())) {
        if (op->hasAttr(kPartialIterationLabel))
          return failure();
      }
    }
    // Apply loop peeling.
    scf::ForOp partialIteration;
    if (failed(peelAndCanonicalizeForLoop(rewriter, forOp, partialIteration)))
      return failure();
    // Apply label, so that the same loop is not rewritten a second time.
    partialIteration->setAttr(kPeeledLoopLabel, rewriter.getUnitAttr());
    rewriter.updateRootInPlace(forOp, [&]() {
      forOp->setAttr(kPeeledLoopLabel, rewriter.getUnitAttr());
    });
    partialIteration->setAttr(kPartialIterationLabel, rewriter.getUnitAttr());
    return success();
  }

  /// If set to true, loops inside partial iterations of another peeled loop
  /// are not peeled. This reduces the size of the generated code. Partial
  /// iterations are not usually performance critical.
  /// Note: Takes into account the entire chain of parent operations, not just
  /// the direct parent.
  bool skipPartial;
};
} // namespace

namespace {
struct ParallelLoopSpecialization
    : public SCFParallelLoopSpecializationBase<ParallelLoopSpecialization> {
  void runOnOperation() override {
    getOperation()->walk(
        [](ParallelOp op) { specializeParallelLoopForUnrolling(op); });
  }
};

struct ForLoopSpecialization
    : public SCFForLoopSpecializationBase<ForLoopSpecialization> {
  void runOnOperation() override {
    getOperation()->walk([](ForOp op) { specializeForLoopForUnrolling(op); });
  }
};

struct ForLoopPeeling : public SCFForLoopPeelingBase<ForLoopPeeling> {
  void runOnOperation() override {
    auto *parentOp = getOperation();
    MLIRContext *ctx = parentOp->getContext();
    RewritePatternSet patterns(ctx);
    patterns.add<ForLoopPeelingPattern>(ctx, skipPartial);
    (void)applyPatternsAndFoldGreedily(parentOp, std::move(patterns));

    // Drop the markers.
    parentOp->walk([](Operation *op) {
      op->removeAttr(kPeeledLoopLabel);
      op->removeAttr(kPartialIterationLabel);
    });
  }
};
} // namespace

std::unique_ptr<Pass> mlir::createParallelLoopSpecializationPass() {
  return std::make_unique<ParallelLoopSpecialization>();
}

std::unique_ptr<Pass> mlir::createForLoopSpecializationPass() {
  return std::make_unique<ForLoopSpecialization>();
}

std::unique_ptr<Pass> mlir::createForLoopPeelingPass() {
  return std::make_unique<ForLoopPeeling>();
}
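
// Usage sketch (illustrative, not part of the pass implementation): the
// factory functions above can be added to a pass pipeline, e.g.
//
//   mlir::PassManager pm(ctx);
//   pm.addPass(mlir::createForLoopPeelingPass());
//
// Assuming the standard SCF pass registration for this revision, the
// corresponding mlir-opt flags are -scf-parallel-loop-specialization,
// -scf-for-loop-specialization, and -scf-for-loop-peeling.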