//===- LoopSpecialization.cpp - scf.parallel/scf.for specialization -------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Specializes parallel loops and for loops for easier unrolling and
// vectorization.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/SCF/Transforms/Passes.h"

#include "mlir/Dialect/Affine/Analysis/AffineStructures.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/SCF/Transforms/Transforms.h"
#include "mlir/Dialect/SCF/Utils/AffineCanonicalizationUtils.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/DenseMap.h"

namespace mlir {
#define GEN_PASS_DEF_SCFFORLOOPPEELING
#define GEN_PASS_DEF_SCFFORLOOPSPECIALIZATION
#define GEN_PASS_DEF_SCFPARALLELLOOPSPECIALIZATION
#include "mlir/Dialect/SCF/Transforms/Passes.h.inc"
} // namespace mlir

using namespace mlir;
using namespace mlir::affine;
using scf::ForOp;
using scf::ParallelOp;

/// Rewrite a parallel loop with bounds defined by an affine.min with a constant
/// into 2 loops after checking if the bounds are equal to that constant. This
/// is beneficial if the loop will almost always have the constant bound and
/// that version can be fully unrolled and vectorized.
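///
/// As a rough illustration (hand-written IR, not taken from a test; the SSA
/// names and the constant 64 are made up), a loop such as
///
///   %min = affine.min affine_map<(d0) -> (d0, 64)>(%n)
///   scf.parallel (%i) = (%c0) to (%min) step (%c1) { ... }
///
/// is rewritten into approximately
///
///   %c64 = arith.constant 64 : index
///   %eq  = arith.cmpi eq, %min, %c64 : index
///   scf.if %eq {
///     scf.parallel (%i) = (%c0) to (%c64) step (%c1) { ... }
///   } else {
///     scf.parallel (%i) = (%c0) to (%min) step (%c1) { ... }
///   }
///
/// where the "then" version has a constant trip count and is therefore easy
/// to unroll and vectorize.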
static void specializeParallelLoopForUnrolling(ParallelOp op) {
  SmallVector<int64_t, 2> constantIndices;
  constantIndices.reserve(op.getUpperBound().size());
  for (auto bound : op.getUpperBound()) {
    auto minOp = bound.getDefiningOp<AffineMinOp>();
    if (!minOp)
      return;
    int64_t minConstant = std::numeric_limits<int64_t>::max();
    for (AffineExpr expr : minOp.getMap().getResults()) {
      if (auto constantIndex = dyn_cast<AffineConstantExpr>(expr))
        minConstant = std::min(minConstant, constantIndex.getValue());
    }
    if (minConstant == std::numeric_limits<int64_t>::max())
      return;
    constantIndices.push_back(minConstant);
  }

  OpBuilder b(op);
  IRMapping map;
  Value cond;
  for (auto bound : llvm::zip(op.getUpperBound(), constantIndices)) {
    Value constant =
        b.create<arith::ConstantIndexOp>(op.getLoc(), std::get<1>(bound));
    Value cmp = b.create<arith::CmpIOp>(op.getLoc(), arith::CmpIPredicate::eq,
                                        std::get<0>(bound), constant);
    cond = cond ? b.create<arith::AndIOp>(op.getLoc(), cond, cmp) : cmp;
    map.map(std::get<0>(bound), constant);
  }
  auto ifOp = b.create<scf::IfOp>(op.getLoc(), cond, /*withElseRegion=*/true);
  ifOp.getThenBodyBuilder().clone(*op.getOperation(), map);
  ifOp.getElseBodyBuilder().clone(*op.getOperation());
  op.erase();
}

/// Rewrite a for loop with bounds defined by an affine.min with a constant into
/// 2 loops after checking if the bounds are equal to that constant. This is
/// beneficial if the loop will almost always have the constant bound and that
/// version can be fully unrolled and vectorized.
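///
/// Illustrative sketch (hand-written IR, names invented): with
///
///   %min = affine.min affine_map<(d0) -> (d0, 64)>(%n)
///   scf.for %i = %c0 to %min step %c1 { ... }
///
/// the loop is wrapped in an scf.if on `%min == 64`; the "then" branch runs a
/// clone of the loop with the constant upper bound %c64, the "else" branch
/// runs the original loop.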
static void specializeForLoopForUnrolling(ForOp op) {
  auto bound = op.getUpperBound();
  auto minOp = bound.getDefiningOp<AffineMinOp>();
  if (!minOp)
    return;
  int64_t minConstant = std::numeric_limits<int64_t>::max();
  for (AffineExpr expr : minOp.getMap().getResults()) {
    if (auto constantIndex = dyn_cast<AffineConstantExpr>(expr))
      minConstant = std::min(minConstant, constantIndex.getValue());
  }
  if (minConstant == std::numeric_limits<int64_t>::max())
    return;

  OpBuilder b(op);
  IRMapping map;
  Value constant = b.create<arith::ConstantIndexOp>(op.getLoc(), minConstant);
  Value cond = b.create<arith::CmpIOp>(op.getLoc(), arith::CmpIPredicate::eq,
                                       bound, constant);
  map.map(bound, constant);
  auto ifOp = b.create<scf::IfOp>(op.getLoc(), cond, /*withElseRegion=*/true);
  ifOp.getThenBodyBuilder().clone(*op.getOperation(), map);
  ifOp.getElseBodyBuilder().clone(*op.getOperation());
  op.erase();
}

/// Rewrite a for loop with bounds/step that potentially do not divide evenly
/// into a for loop where the step divides the iteration space evenly, followed
/// by another scf.for for the last (partial) iteration (if any).
///
/// This function rewrites the given scf.for loop in-place and creates a new
/// scf.for operation for the last (partial) iteration. It replaces all uses of
/// the original loop's results with the results of the newly generated loop.
///
/// The newly generated scf.for operation is returned via `partialIteration`.
/// The boundary at which the loop is split (new upper bound) is returned via
/// `splitBound`. The return value indicates whether the loop was rewritten or
/// not.
static LogicalResult peelForLoop(RewriterBase &b, ForOp forOp,
                                 ForOp &partialIteration, Value &splitBound) {
  RewriterBase::InsertionGuard guard(b);
  auto lbInt = getConstantIntValue(forOp.getLowerBound());
  auto ubInt = getConstantIntValue(forOp.getUpperBound());
  auto stepInt = getConstantIntValue(forOp.getStep());

  // No specialization necessary if step size is 1. Also bail out in case of an
  // invalid zero or negative step which might have happened during folding.
  if (stepInt && *stepInt <= 1)
    return failure();

  // No specialization necessary if step already divides upper bound evenly.
  // Fast path: lb, ub and step are constants.
  if (lbInt && ubInt && stepInt && (*ubInt - *lbInt) % *stepInt == 0)
    return failure();
  // Slow path: Examine the ops that define lb, ub and step.
  AffineExpr sym0, sym1, sym2;
  bindSymbols(b.getContext(), sym0, sym1, sym2);
  SmallVector<Value> operands{forOp.getLowerBound(), forOp.getUpperBound(),
                              forOp.getStep()};
  AffineMap map = AffineMap::get(0, 3, {(sym1 - sym0) % sym2});
  affine::fullyComposeAffineMapAndOperands(&map, &operands);
  if (auto constExpr = dyn_cast<AffineConstantExpr>(map.getResult(0)))
    if (constExpr.getValue() == 0)
      return failure();

  // New upper bound: %ub - (%ub - %lb) mod %step
  auto modMap = AffineMap::get(0, 3, {sym1 - ((sym1 - sym0) % sym2)});
  b.setInsertionPoint(forOp);
  auto loc = forOp.getLoc();
  splitBound = b.createOrFold<AffineApplyOp>(loc, modMap,
                                             ValueRange{forOp.getLowerBound(),
                                                        forOp.getUpperBound(),
                                                        forOp.getStep()});

  // Create ForOp for partial iteration.
  b.setInsertionPointAfter(forOp);
  partialIteration = cast<ForOp>(b.clone(*forOp.getOperation()));
  partialIteration.getLowerBoundMutable().assign(splitBound);
  b.replaceAllUsesWith(forOp.getResults(), partialIteration->getResults());
  partialIteration.getInitArgsMutable().assign(forOp->getResults());

  // Set new upper loop bound.
  b.modifyOpInPlace(forOp,
                    [&]() { forOp.getUpperBoundMutable().assign(splitBound); });

  return success();
}

static void rewriteAffineOpAfterPeeling(RewriterBase &rewriter, ForOp forOp,
                                        ForOp partialIteration,
                                        Value previousUb) {
  Value mainIv = forOp.getInductionVar();
  Value partialIv = partialIteration.getInductionVar();
  assert(forOp.getStep() == partialIteration.getStep() &&
         "expected same step in main and partial loop");
  Value step = forOp.getStep();

  forOp.walk([&](Operation *affineOp) {
    if (!isa<AffineMinOp, AffineMaxOp>(affineOp))
      return WalkResult::advance();
    (void)scf::rewritePeeledMinMaxOp(rewriter, affineOp, mainIv, previousUb,
                                     step,
                                     /*insideLoop=*/true);
    return WalkResult::advance();
  });
  partialIteration.walk([&](Operation *affineOp) {
    if (!isa<AffineMinOp, AffineMaxOp>(affineOp))
      return WalkResult::advance();
    (void)scf::rewritePeeledMinMaxOp(rewriter, affineOp, partialIv, previousUb,
                                     step, /*insideLoop=*/false);
    return WalkResult::advance();
  });
}

LogicalResult mlir::scf::peelForLoopAndSimplifyBounds(RewriterBase &rewriter,
                                                      ForOp forOp,
                                                      ForOp &partialIteration) {
  Value previousUb = forOp.getUpperBound();
  Value splitBound;
  if (failed(peelForLoop(rewriter, forOp, partialIteration, splitBound)))
    return failure();

  // Rewrite affine.min and affine.max ops.
  rewriteAffineOpAfterPeeling(rewriter, forOp, partialIteration, previousUb);

  return success();
}

/// Rewrites the original scf::ForOp as two scf::ForOp ops: the first one
/// covers only the first iteration of the loop and can be canonicalized away
/// by subsequent optimizations. The second one contains the remaining
/// iterations, with its lower bound updated to the original lower bound plus
/// the step (i.e. it skips the first iteration).
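///
/// Illustrative sketch (hand-written IR, names invented):
///
///   %r = scf.for %iv = %lb to %ub step %s iter_args(%a = %init) -> (f32) {...}
///
/// becomes approximately
///
///   %split = affine.apply affine_map<()[s0, s1] -> (s0 + s1)>()[%lb, %s]
///   %first = scf.for %iv = %lb to %split step %s
///            iter_args(%a = %init) -> (f32) {...}
///   %r = scf.for %iv = %split to %ub step %s
///        iter_args(%a = %first) -> (f32) {...}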
LogicalResult mlir::scf::peelForLoopFirstIteration(RewriterBase &b, ForOp forOp,
                                                   ForOp &firstIteration) {
  RewriterBase::InsertionGuard guard(b);
  auto lbInt = getConstantIntValue(forOp.getLowerBound());
  auto ubInt = getConstantIntValue(forOp.getUpperBound());
  auto stepInt = getConstantIntValue(forOp.getStep());

  // Peeling is not needed if the loop has at most one iteration.
  if (lbInt && ubInt && stepInt && ceil(float(*ubInt - *lbInt) / *stepInt) <= 1)
    return failure();

  AffineExpr lbSymbol, stepSymbol;
  bindSymbols(b.getContext(), lbSymbol, stepSymbol);

  // New lower bound for main loop: %lb + %step
  auto ubMap = AffineMap::get(0, 2, {lbSymbol + stepSymbol});
  b.setInsertionPoint(forOp);
  auto loc = forOp.getLoc();
  Value splitBound = b.createOrFold<AffineApplyOp>(
      loc, ubMap, ValueRange{forOp.getLowerBound(), forOp.getStep()});

  // Peel the first iteration.
  IRMapping map;
  map.map(forOp.getUpperBound(), splitBound);
  firstIteration = cast<ForOp>(b.clone(*forOp.getOperation(), map));

  // Update main loop with new lower bound.
  b.modifyOpInPlace(forOp, [&]() {
    forOp.getInitArgsMutable().assign(firstIteration->getResults());
    forOp.getLowerBoundMutable().assign(splitBound);
  });

  return success();
}

static constexpr char kPeeledLoopLabel[] = "__peeled_loop__";
static constexpr char kPartialIterationLabel[] = "__partial_iteration__";

namespace {
struct ForLoopPeelingPattern : public OpRewritePattern<ForOp> {
  ForLoopPeelingPattern(MLIRContext *ctx, bool peelFront, bool skipPartial)
      : OpRewritePattern<ForOp>(ctx), peelFront(peelFront),
        skipPartial(skipPartial) {}

  LogicalResult matchAndRewrite(ForOp forOp,
                                PatternRewriter &rewriter) const override {
    // Do not peel already peeled loops.
    if (forOp->hasAttr(kPeeledLoopLabel))
      return failure();

    scf::ForOp partialIteration;
    // The case for peeling the first iteration of the loop.
    if (peelFront) {
      if (failed(
              peelForLoopFirstIteration(rewriter, forOp, partialIteration))) {
        return failure();
      }
    } else {
      if (skipPartial) {
        // No peeling of loops inside the partial iteration of another peeled
        // loop.
        Operation *op = forOp.getOperation();
        while ((op = op->getParentOfType<scf::ForOp>())) {
          if (op->hasAttr(kPartialIterationLabel))
            return failure();
        }
      }
      // Apply loop peeling.
      if (failed(
              peelForLoopAndSimplifyBounds(rewriter, forOp, partialIteration)))
        return failure();
    }

    // Apply label, so that the same loop is not rewritten a second time.
    rewriter.modifyOpInPlace(partialIteration, [&]() {
      partialIteration->setAttr(kPeeledLoopLabel, rewriter.getUnitAttr());
      partialIteration->setAttr(kPartialIterationLabel, rewriter.getUnitAttr());
    });
    rewriter.modifyOpInPlace(forOp, [&]() {
      forOp->setAttr(kPeeledLoopLabel, rewriter.getUnitAttr());
    });
    return success();
  }

  /// If set to true, the first iteration of the loop will be peeled. Otherwise,
  /// the unevenly divisible loop will be peeled at the end.
  bool peelFront;

  /// If set to true, loops inside partial iterations of another peeled loop
  /// are not peeled. This reduces the size of the generated code. Partial
  /// iterations are not usually performance critical.
  /// Note: Takes into account the entire chain of parent operations, not just
  /// the direct parent.
  bool skipPartial;
};
} // namespace

namespace {
struct ParallelLoopSpecialization
    : public impl::SCFParallelLoopSpecializationBase<
          ParallelLoopSpecialization> {
  void runOnOperation() override {
    getOperation()->walk(
        [](ParallelOp op) { specializeParallelLoopForUnrolling(op); });
  }
};

struct ForLoopSpecialization
    : public impl::SCFForLoopSpecializationBase<ForLoopSpecialization> {
  void runOnOperation() override {
    getOperation()->walk([](ForOp op) { specializeForLoopForUnrolling(op); });
  }
};

struct ForLoopPeeling : public impl::SCFForLoopPeelingBase<ForLoopPeeling> {
  void runOnOperation() override {
    auto *parentOp = getOperation();
    MLIRContext *ctx = parentOp->getContext();
    RewritePatternSet patterns(ctx);
    patterns.add<ForLoopPeelingPattern>(ctx, peelFront, skipPartial);
    (void)applyPatternsGreedily(parentOp, std::move(patterns));

    // Drop the markers.
    parentOp->walk([](Operation *op) {
      op->removeAttr(kPeeledLoopLabel);
      op->removeAttr(kPartialIterationLabel);
    });
  }
};
} // namespace

std::unique_ptr<Pass> mlir::createParallelLoopSpecializationPass() {
  return std::make_unique<ParallelLoopSpecialization>();
}

std::unique_ptr<Pass> mlir::createForLoopSpecializationPass() {
  return std::make_unique<ForLoopSpecialization>();
}

std::unique_ptr<Pass> mlir::createForLoopPeelingPass() {
  return std::make_unique<ForLoopPeeling>();
}
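
// A minimal sketch of programmatic use outside of the passes above (assumes a
// caller that already holds an scf::ForOp; `loop` is a made-up name):
//
//   IRRewriter rewriter(loop->getContext());
//   scf::ForOp partialIteration;
//   if (succeeded(scf::peelForLoopAndSimplifyBounds(rewriter, loop,
//                                                   partialIteration))) {
//     // `loop` now has an evenly dividing iteration space; the remainder, if
//     // any, lives in `partialIteration`.
//   }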