xref: /llvm-project/mlir/lib/Dialect/SCF/Transforms/LoopSpecialization.cpp (revision f6f1ab9d90252f9b943e77a64e30a3d26ef7cbbb)
//===- LoopSpecialization.cpp - scf.parallel/scf.for specialization -------===//
24bcd08ebSStephan Herhut //
34bcd08ebSStephan Herhut // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
44bcd08ebSStephan Herhut // See https://llvm.org/LICENSE.txt for license information.
54bcd08ebSStephan Herhut // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
64bcd08ebSStephan Herhut //
74bcd08ebSStephan Herhut //===----------------------------------------------------------------------===//
84bcd08ebSStephan Herhut //
94bcd08ebSStephan Herhut // Specializes parallel loops and for loops for easier unrolling and
104bcd08ebSStephan Herhut // vectorization.
114bcd08ebSStephan Herhut //
124bcd08ebSStephan Herhut //===----------------------------------------------------------------------===//
134bcd08ebSStephan Herhut 
1467d0d7acSMichele Scuttari #include "mlir/Dialect/SCF/Transforms/Passes.h"
1567d0d7acSMichele Scuttari 
16755dc07dSRiver Riddle #include "mlir/Dialect/Affine/Analysis/AffineStructures.h"
174bcd08ebSStephan Herhut #include "mlir/Dialect/Affine/IR/AffineOps.h"
18abc362a1SJakub Kuderski #include "mlir/Dialect/Arith/IR/Arith.h"
198b68da2cSAlex Zinenko #include "mlir/Dialect/SCF/IR/SCF.h"
208b68da2cSAlex Zinenko #include "mlir/Dialect/SCF/Transforms/Transforms.h"
21f40475c7SAdrian Kuegel #include "mlir/Dialect/SCF/Utils/AffineCanonicalizationUtils.h"
223a41ff48SMatthias Springer #include "mlir/Dialect/Utils/StaticValueUtils.h"
234bcd08ebSStephan Herhut #include "mlir/IR/AffineExpr.h"
244d67b278SJeff Niu #include "mlir/IR/IRMapping.h"
253a41ff48SMatthias Springer #include "mlir/IR/PatternMatch.h"
263a41ff48SMatthias Springer #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
273a41ff48SMatthias Springer #include "llvm/ADT/DenseMap.h"
284bcd08ebSStephan Herhut 
2967d0d7acSMichele Scuttari namespace mlir {
3067d0d7acSMichele Scuttari #define GEN_PASS_DEF_SCFFORLOOPPEELING
3167d0d7acSMichele Scuttari #define GEN_PASS_DEF_SCFFORLOOPSPECIALIZATION
3267d0d7acSMichele Scuttari #define GEN_PASS_DEF_SCFPARALLELLOOPSPECIALIZATION
3367d0d7acSMichele Scuttari #include "mlir/Dialect/SCF/Transforms/Passes.h.inc"
3467d0d7acSMichele Scuttari } // namespace mlir
3567d0d7acSMichele Scuttari 
364bcd08ebSStephan Herhut using namespace mlir;
374c48f016SMatthias Springer using namespace mlir::affine;
384bcd08ebSStephan Herhut using scf::ForOp;
394bcd08ebSStephan Herhut using scf::ParallelOp;
404bcd08ebSStephan Herhut 
414bcd08ebSStephan Herhut /// Rewrite a parallel loop with bounds defined by an affine.min with a constant
424bcd08ebSStephan Herhut /// into 2 loops after checking if the bounds are equal to that constant. This
434bcd08ebSStephan Herhut /// is beneficial if the loop will almost always have the constant bound and
444bcd08ebSStephan Herhut /// that version can be fully unrolled and vectorized.
454bcd08ebSStephan Herhut static void specializeParallelLoopForUnrolling(ParallelOp op) {
464bcd08ebSStephan Herhut   SmallVector<int64_t, 2> constantIndices;
47c0342a2dSJacques Pienaar   constantIndices.reserve(op.getUpperBound().size());
48c0342a2dSJacques Pienaar   for (auto bound : op.getUpperBound()) {
494bcd08ebSStephan Herhut     auto minOp = bound.getDefiningOp<AffineMinOp>();
504bcd08ebSStephan Herhut     if (!minOp)
514bcd08ebSStephan Herhut       return;
524bcd08ebSStephan Herhut     int64_t minConstant = std::numeric_limits<int64_t>::max();
5304235d07SJacques Pienaar     for (AffineExpr expr : minOp.getMap().getResults()) {
541609f1c2Slong.chen       if (auto constantIndex = dyn_cast<AffineConstantExpr>(expr))
554bcd08ebSStephan Herhut         minConstant = std::min(minConstant, constantIndex.getValue());
564bcd08ebSStephan Herhut     }
574bcd08ebSStephan Herhut     if (minConstant == std::numeric_limits<int64_t>::max())
584bcd08ebSStephan Herhut       return;
594bcd08ebSStephan Herhut     constantIndices.push_back(minConstant);
604bcd08ebSStephan Herhut   }
614bcd08ebSStephan Herhut 
624bcd08ebSStephan Herhut   OpBuilder b(op);
634d67b278SJeff Niu   IRMapping map;
644bcd08ebSStephan Herhut   Value cond;
65c0342a2dSJacques Pienaar   for (auto bound : llvm::zip(op.getUpperBound(), constantIndices)) {
66a54f4eaeSMogball     Value constant =
67a54f4eaeSMogball         b.create<arith::ConstantIndexOp>(op.getLoc(), std::get<1>(bound));
68a54f4eaeSMogball     Value cmp = b.create<arith::CmpIOp>(op.getLoc(), arith::CmpIPredicate::eq,
694bcd08ebSStephan Herhut                                         std::get<0>(bound), constant);
70a54f4eaeSMogball     cond = cond ? b.create<arith::AndIOp>(op.getLoc(), cond, cmp) : cmp;
714bcd08ebSStephan Herhut     map.map(std::get<0>(bound), constant);
724bcd08ebSStephan Herhut   }
734bcd08ebSStephan Herhut   auto ifOp = b.create<scf::IfOp>(op.getLoc(), cond, /*withElseRegion=*/true);
744bcd08ebSStephan Herhut   ifOp.getThenBodyBuilder().clone(*op.getOperation(), map);
754bcd08ebSStephan Herhut   ifOp.getElseBodyBuilder().clone(*op.getOperation());
764bcd08ebSStephan Herhut   op.erase();
774bcd08ebSStephan Herhut }
784bcd08ebSStephan Herhut 
794bcd08ebSStephan Herhut /// Rewrite a for loop with bounds defined by an affine.min with a constant into
804bcd08ebSStephan Herhut /// 2 loops after checking if the bounds are equal to that constant. This is
814bcd08ebSStephan Herhut /// beneficial if the loop will almost always have the constant bound and that
824bcd08ebSStephan Herhut /// version can be fully unrolled and vectorized.
834bcd08ebSStephan Herhut static void specializeForLoopForUnrolling(ForOp op) {
84c0342a2dSJacques Pienaar   auto bound = op.getUpperBound();
854bcd08ebSStephan Herhut   auto minOp = bound.getDefiningOp<AffineMinOp>();
864bcd08ebSStephan Herhut   if (!minOp)
874bcd08ebSStephan Herhut     return;
884bcd08ebSStephan Herhut   int64_t minConstant = std::numeric_limits<int64_t>::max();
8904235d07SJacques Pienaar   for (AffineExpr expr : minOp.getMap().getResults()) {
901609f1c2Slong.chen     if (auto constantIndex = dyn_cast<AffineConstantExpr>(expr))
914bcd08ebSStephan Herhut       minConstant = std::min(minConstant, constantIndex.getValue());
924bcd08ebSStephan Herhut   }
934bcd08ebSStephan Herhut   if (minConstant == std::numeric_limits<int64_t>::max())
944bcd08ebSStephan Herhut     return;
954bcd08ebSStephan Herhut 
964bcd08ebSStephan Herhut   OpBuilder b(op);
974d67b278SJeff Niu   IRMapping map;
98a54f4eaeSMogball   Value constant = b.create<arith::ConstantIndexOp>(op.getLoc(), minConstant);
99a54f4eaeSMogball   Value cond = b.create<arith::CmpIOp>(op.getLoc(), arith::CmpIPredicate::eq,
100a54f4eaeSMogball                                        bound, constant);
1014bcd08ebSStephan Herhut   map.map(bound, constant);
1024bcd08ebSStephan Herhut   auto ifOp = b.create<scf::IfOp>(op.getLoc(), cond, /*withElseRegion=*/true);
1034bcd08ebSStephan Herhut   ifOp.getThenBodyBuilder().clone(*op.getOperation(), map);
1044bcd08ebSStephan Herhut   ifOp.getElseBodyBuilder().clone(*op.getOperation());
1054bcd08ebSStephan Herhut   op.erase();
1064bcd08ebSStephan Herhut }
1074bcd08ebSStephan Herhut 
1083a41ff48SMatthias Springer /// Rewrite a for loop with bounds/step that potentially do not divide evenly
1093a41ff48SMatthias Springer /// into a for loop where the step divides the iteration space evenly, followed
1103a41ff48SMatthias Springer /// by an scf.if for the last (partial) iteration (if any).
1118e8b70aaSMatthias Springer ///
1128e8b70aaSMatthias Springer /// This function rewrites the given scf.for loop in-place and creates a new
1138e8b70aaSMatthias Springer /// scf.if operation for the last iteration. It replaces all uses of the
1148e8b70aaSMatthias Springer /// unpeeled loop with the results of the newly generated scf.if.
1158e8b70aaSMatthias Springer ///
1168e8b70aaSMatthias Springer /// The newly generated scf.if operation is returned via `ifOp`. The boundary
1178e8b70aaSMatthias Springer /// at which the loop is split (new upper bound) is returned via `splitBound`.
1188e8b70aaSMatthias Springer /// The return value indicates whether the loop was rewritten or not.
1190f3544d1SMatthias Springer static LogicalResult peelForLoop(RewriterBase &b, ForOp forOp,
1200f3544d1SMatthias Springer                                  ForOp &partialIteration, Value &splitBound) {
1213a41ff48SMatthias Springer   RewriterBase::InsertionGuard guard(b);
122c0342a2dSJacques Pienaar   auto lbInt = getConstantIntValue(forOp.getLowerBound());
123c0342a2dSJacques Pienaar   auto ubInt = getConstantIntValue(forOp.getUpperBound());
124c0342a2dSJacques Pienaar   auto stepInt = getConstantIntValue(forOp.getStep());
1253a41ff48SMatthias Springer 
126*f6f1ab9dSFelix Schneider   // No specialization necessary if step size is 1. Also bail out in case of an
127*f6f1ab9dSFelix Schneider   // invalid zero or negative step which might have happened during folding.
128*f6f1ab9dSFelix Schneider   if (stepInt && *stepInt <= 1)
1293a41ff48SMatthias Springer     return failure();
1303a41ff48SMatthias Springer 
13196901f1bSMatthias Springer   // No specialization necessary if step already divides upper bound evenly.
13296901f1bSMatthias Springer   // Fast path: lb, ub and step are constants.
13396901f1bSMatthias Springer   if (lbInt && ubInt && stepInt && (*ubInt - *lbInt) % *stepInt == 0)
13496901f1bSMatthias Springer     return failure();
13596901f1bSMatthias Springer   // Slow path: Examine the ops that define lb, ub and step.
1360c360829SMatthias Springer   AffineExpr sym0, sym1, sym2;
1370c360829SMatthias Springer   bindSymbols(b.getContext(), sym0, sym1, sym2);
13896901f1bSMatthias Springer   SmallVector<Value> operands{forOp.getLowerBound(), forOp.getUpperBound(),
13996901f1bSMatthias Springer                               forOp.getStep()};
14096901f1bSMatthias Springer   AffineMap map = AffineMap::get(0, 3, {(sym1 - sym0) % sym2});
14196901f1bSMatthias Springer   affine::fullyComposeAffineMapAndOperands(&map, &operands);
14296901f1bSMatthias Springer   if (auto constExpr = dyn_cast<AffineConstantExpr>(map.getResult(0)))
14396901f1bSMatthias Springer     if (constExpr.getValue() == 0)
14496901f1bSMatthias Springer       return failure();
14596901f1bSMatthias Springer 
1463a41ff48SMatthias Springer   // New upper bound: %ub - (%ub - %lb) mod %step
1470c360829SMatthias Springer   auto modMap = AffineMap::get(0, 3, {sym1 - ((sym1 - sym0) % sym2)});
148767974f3SMatthias Springer   b.setInsertionPoint(forOp);
14996901f1bSMatthias Springer   auto loc = forOp.getLoc();
150c0342a2dSJacques Pienaar   splitBound = b.createOrFold<AffineApplyOp>(loc, modMap,
151c0342a2dSJacques Pienaar                                              ValueRange{forOp.getLowerBound(),
152c0342a2dSJacques Pienaar                                                         forOp.getUpperBound(),
153c0342a2dSJacques Pienaar                                                         forOp.getStep()});
1543a41ff48SMatthias Springer 
1550f3544d1SMatthias Springer   // Create ForOp for partial iteration.
1560f3544d1SMatthias Springer   b.setInsertionPointAfter(forOp);
1570f3544d1SMatthias Springer   partialIteration = cast<ForOp>(b.clone(*forOp.getOperation()));
158c0342a2dSJacques Pienaar   partialIteration.getLowerBoundMutable().assign(splitBound);
15961f37758SMatthias Springer   b.replaceAllUsesWith(forOp.getResults(), partialIteration->getResults());
160c0342a2dSJacques Pienaar   partialIteration.getInitArgsMutable().assign(forOp->getResults());
1610f3544d1SMatthias Springer 
1623a41ff48SMatthias Springer   // Set new upper loop bound.
163c0342a2dSJacques Pienaar   b.updateRootInPlace(
164c0342a2dSJacques Pienaar       forOp, [&]() { forOp.getUpperBoundMutable().assign(splitBound); });
1653a41ff48SMatthias Springer 
1663a41ff48SMatthias Springer   return success();
1673a41ff48SMatthias Springer }
1683a41ff48SMatthias Springer 
1690f3544d1SMatthias Springer static void rewriteAffineOpAfterPeeling(RewriterBase &rewriter, ForOp forOp,
1700f3544d1SMatthias Springer                                         ForOp partialIteration,
1710f3544d1SMatthias Springer                                         Value previousUb) {
1720f3544d1SMatthias Springer   Value mainIv = forOp.getInductionVar();
1730f3544d1SMatthias Springer   Value partialIv = partialIteration.getInductionVar();
174c0342a2dSJacques Pienaar   assert(forOp.getStep() == partialIteration.getStep() &&
1750f3544d1SMatthias Springer          "expected same step in main and partial loop");
176c0342a2dSJacques Pienaar   Value step = forOp.getStep();
1770f3544d1SMatthias Springer 
1783a5811a3SMatthias Springer   forOp.walk([&](Operation *affineOp) {
1793a5811a3SMatthias Springer     if (!isa<AffineMinOp, AffineMaxOp>(affineOp))
1803a5811a3SMatthias Springer       return WalkResult::advance();
1813a5811a3SMatthias Springer     (void)scf::rewritePeeledMinMaxOp(rewriter, affineOp, mainIv, previousUb,
1823a5811a3SMatthias Springer                                      step,
183a9cff97fSMatthias Springer                                      /*insideLoop=*/true);
1843a5811a3SMatthias Springer     return WalkResult::advance();
185a9cff97fSMatthias Springer   });
1863a5811a3SMatthias Springer   partialIteration.walk([&](Operation *affineOp) {
1873a5811a3SMatthias Springer     if (!isa<AffineMinOp, AffineMaxOp>(affineOp))
1883a5811a3SMatthias Springer       return WalkResult::advance();
1893a5811a3SMatthias Springer     (void)scf::rewritePeeledMinMaxOp(rewriter, affineOp, partialIv, previousUb,
1903a5811a3SMatthias Springer                                      step, /*insideLoop=*/false);
1913a5811a3SMatthias Springer     return WalkResult::advance();
192a9cff97fSMatthias Springer   });
1938e8b70aaSMatthias Springer }
1948e8b70aaSMatthias Springer 
195bb2ae985SNicolas Vasilache LogicalResult mlir::scf::peelForLoopAndSimplifyBounds(RewriterBase &rewriter,
196bc194a5bSMatthias Springer                                                       ForOp forOp,
1970f3544d1SMatthias Springer                                                       ForOp &partialIteration) {
198c0342a2dSJacques Pienaar   Value previousUb = forOp.getUpperBound();
1998e8b70aaSMatthias Springer   Value splitBound;
2000f3544d1SMatthias Springer   if (failed(peelForLoop(rewriter, forOp, partialIteration, splitBound)))
2018e8b70aaSMatthias Springer     return failure();
2028e8b70aaSMatthias Springer 
203a9cff97fSMatthias Springer   // Rewrite affine.min and affine.max ops.
2043a5811a3SMatthias Springer   rewriteAffineOpAfterPeeling(rewriter, forOp, partialIteration, previousUb);
2058e8b70aaSMatthias Springer 
2068e8b70aaSMatthias Springer   return success();
2078e8b70aaSMatthias Springer }
2088e8b70aaSMatthias Springer 
209bd6a2452SVivian /// When the `peelFront` option is set as true, the first iteration of the loop
210bd6a2452SVivian /// is peeled off. This function rewrites the original scf::ForOp as two
211bd6a2452SVivian /// scf::ForOp Ops, the first scf::ForOp corresponds to the first iteration of
212bd6a2452SVivian /// the loop which can be canonicalized away in the following optimization. The
213bd6a2452SVivian /// second loop Op contains the remaining iteration, and the new lower bound is
214bd6a2452SVivian /// the original lower bound plus the number of steps.
215bd6a2452SVivian LogicalResult mlir::scf::peelForLoopFirstIteration(RewriterBase &b, ForOp forOp,
216bd6a2452SVivian                                                    ForOp &firstIteration) {
217bd6a2452SVivian   RewriterBase::InsertionGuard guard(b);
218bd6a2452SVivian   auto lbInt = getConstantIntValue(forOp.getLowerBound());
219bd6a2452SVivian   auto ubInt = getConstantIntValue(forOp.getUpperBound());
220bd6a2452SVivian   auto stepInt = getConstantIntValue(forOp.getStep());
221bd6a2452SVivian 
222bd6a2452SVivian   // Peeling is not needed if there is one or less iteration.
223bd6a2452SVivian   if (lbInt && ubInt && stepInt && (*ubInt - *lbInt) / *stepInt <= 1)
224bd6a2452SVivian     return failure();
225bd6a2452SVivian 
226bd6a2452SVivian   AffineExpr lbSymbol, stepSymbol;
227bd6a2452SVivian   bindSymbols(b.getContext(), lbSymbol, stepSymbol);
228bd6a2452SVivian 
229bd6a2452SVivian   // New lower bound for main loop: %lb + %step
230bd6a2452SVivian   auto ubMap = AffineMap::get(0, 2, {lbSymbol + stepSymbol});
231bd6a2452SVivian   b.setInsertionPoint(forOp);
232bd6a2452SVivian   auto loc = forOp.getLoc();
233bd6a2452SVivian   Value splitBound = b.createOrFold<AffineApplyOp>(
234bd6a2452SVivian       loc, ubMap, ValueRange{forOp.getLowerBound(), forOp.getStep()});
235bd6a2452SVivian 
236bd6a2452SVivian   // Peel the first iteration.
237bd6a2452SVivian   IRMapping map;
238bd6a2452SVivian   map.map(forOp.getUpperBound(), splitBound);
239bd6a2452SVivian   firstIteration = cast<ForOp>(b.clone(*forOp.getOperation(), map));
240bd6a2452SVivian 
241bd6a2452SVivian   // Update main loop with new lower bound.
242bd6a2452SVivian   b.updateRootInPlace(forOp, [&]() {
243bd6a2452SVivian     forOp.getInitArgsMutable().assign(firstIteration->getResults());
244bd6a2452SVivian     forOp.getLowerBoundMutable().assign(splitBound);
245bd6a2452SVivian   });
246bd6a2452SVivian 
247bd6a2452SVivian   return success();
248bd6a2452SVivian }
249bd6a2452SVivian 
namespace {
/// Marker attribute set on both loops produced by peeling so that the peeling
/// pattern does not rewrite them again.
constexpr char kPeeledLoopLabel[] = "__peeled_loop__";
/// Marker attribute set on the loop that holds the peeled-off iteration; the
/// `skipPartial` option checks ancestors for it.
constexpr char kPartialIterationLabel[] = "__partial_iteration__";
} // namespace
2523a41ff48SMatthias Springer 
2533a41ff48SMatthias Springer namespace {
2543a41ff48SMatthias Springer struct ForLoopPeelingPattern : public OpRewritePattern<ForOp> {
255bd6a2452SVivian   ForLoopPeelingPattern(MLIRContext *ctx, bool peelFront, bool skipPartial)
256bd6a2452SVivian       : OpRewritePattern<ForOp>(ctx), peelFront(peelFront),
257bd6a2452SVivian         skipPartial(skipPartial) {}
2583a41ff48SMatthias Springer 
2593a41ff48SMatthias Springer   LogicalResult matchAndRewrite(ForOp forOp,
2603a41ff48SMatthias Springer                                 PatternRewriter &rewriter) const override {
261bc194a5bSMatthias Springer     // Do not peel already peeled loops.
2623a41ff48SMatthias Springer     if (forOp->hasAttr(kPeeledLoopLabel))
2633a41ff48SMatthias Springer       return failure();
264bd6a2452SVivian 
265bd6a2452SVivian     scf::ForOp partialIteration;
266bd6a2452SVivian     // The case for peeling the first iteration of the loop.
267bd6a2452SVivian     if (peelFront) {
268bd6a2452SVivian       if (failed(
269bd6a2452SVivian               peelForLoopFirstIteration(rewriter, forOp, partialIteration))) {
270bd6a2452SVivian         return failure();
271bd6a2452SVivian       }
272bd6a2452SVivian     } else {
273bc194a5bSMatthias Springer       if (skipPartial) {
2740f3544d1SMatthias Springer         // No peeling of loops inside the partial iteration of another peeled
2750f3544d1SMatthias Springer         // loop.
276bc194a5bSMatthias Springer         Operation *op = forOp.getOperation();
2770f3544d1SMatthias Springer         while ((op = op->getParentOfType<scf::ForOp>())) {
278bc194a5bSMatthias Springer           if (op->hasAttr(kPartialIterationLabel))
279bc194a5bSMatthias Springer             return failure();
280bc194a5bSMatthias Springer         }
281bc194a5bSMatthias Springer       }
282bc194a5bSMatthias Springer       // Apply loop peeling.
283bd6a2452SVivian       if (failed(
284bd6a2452SVivian               peelForLoopAndSimplifyBounds(rewriter, forOp, partialIteration)))
2853a41ff48SMatthias Springer         return failure();
286bd6a2452SVivian     }
287bd6a2452SVivian 
2883a41ff48SMatthias Springer     // Apply label, so that the same loop is not rewritten a second time.
28961f37758SMatthias Springer     rewriter.updateRootInPlace(partialIteration, [&]() {
2900f3544d1SMatthias Springer       partialIteration->setAttr(kPeeledLoopLabel, rewriter.getUnitAttr());
29161f37758SMatthias Springer       partialIteration->setAttr(kPartialIterationLabel, rewriter.getUnitAttr());
29261f37758SMatthias Springer     });
2933a41ff48SMatthias Springer     rewriter.updateRootInPlace(forOp, [&]() {
2943a41ff48SMatthias Springer       forOp->setAttr(kPeeledLoopLabel, rewriter.getUnitAttr());
2953a41ff48SMatthias Springer     });
2963a41ff48SMatthias Springer     return success();
2973a41ff48SMatthias Springer   }
298bc194a5bSMatthias Springer 
299bd6a2452SVivian   // If set to true, the first iteration of the loop will be peeled. Otherwise,
300bd6a2452SVivian   // the unevenly divisible loop will be peeled at the end.
301bd6a2452SVivian   bool peelFront;
302bd6a2452SVivian 
303bc194a5bSMatthias Springer   /// If set to true, loops inside partial iterations of another peeled loop
304bc194a5bSMatthias Springer   /// are not peeled. This reduces the size of the generated code. Partial
305bc194a5bSMatthias Springer   /// iterations are not usually performance critical.
306bc194a5bSMatthias Springer   /// Note: Takes into account the entire chain of parent operations, not just
307bc194a5bSMatthias Springer   /// the direct parent.
308bc194a5bSMatthias Springer   bool skipPartial;
3093a41ff48SMatthias Springer };
3103a41ff48SMatthias Springer } // namespace
3113a41ff48SMatthias Springer 
3124bcd08ebSStephan Herhut namespace {
313039b969bSMichele Scuttari struct ParallelLoopSpecialization
31467d0d7acSMichele Scuttari     : public impl::SCFParallelLoopSpecializationBase<
31567d0d7acSMichele Scuttari           ParallelLoopSpecialization> {
31641574554SRiver Riddle   void runOnOperation() override {
31754998986SStella Laurenzo     getOperation()->walk(
3184bcd08ebSStephan Herhut         [](ParallelOp op) { specializeParallelLoopForUnrolling(op); });
3194bcd08ebSStephan Herhut   }
3204bcd08ebSStephan Herhut };
3214bcd08ebSStephan Herhut 
322039b969bSMichele Scuttari struct ForLoopSpecialization
32367d0d7acSMichele Scuttari     : public impl::SCFForLoopSpecializationBase<ForLoopSpecialization> {
32441574554SRiver Riddle   void runOnOperation() override {
32554998986SStella Laurenzo     getOperation()->walk([](ForOp op) { specializeForLoopForUnrolling(op); });
3264bcd08ebSStephan Herhut   }
3274bcd08ebSStephan Herhut };
3283a41ff48SMatthias Springer 
32967d0d7acSMichele Scuttari struct ForLoopPeeling : public impl::SCFForLoopPeelingBase<ForLoopPeeling> {
33041574554SRiver Riddle   void runOnOperation() override {
33154998986SStella Laurenzo     auto *parentOp = getOperation();
33254998986SStella Laurenzo     MLIRContext *ctx = parentOp->getContext();
3333a41ff48SMatthias Springer     RewritePatternSet patterns(ctx);
334bd6a2452SVivian     patterns.add<ForLoopPeelingPattern>(ctx, peelFront, skipPartial);
33554998986SStella Laurenzo     (void)applyPatternsAndFoldGreedily(parentOp, std::move(patterns));
3363a41ff48SMatthias Springer 
337bc194a5bSMatthias Springer     // Drop the markers.
33854998986SStella Laurenzo     parentOp->walk([](Operation *op) {
339bc194a5bSMatthias Springer       op->removeAttr(kPeeledLoopLabel);
340bc194a5bSMatthias Springer       op->removeAttr(kPartialIterationLabel);
341bc194a5bSMatthias Springer     });
3423a41ff48SMatthias Springer   }
3433a41ff48SMatthias Springer };
3444bcd08ebSStephan Herhut } // namespace
345039b969bSMichele Scuttari 
346039b969bSMichele Scuttari std::unique_ptr<Pass> mlir::createParallelLoopSpecializationPass() {
347039b969bSMichele Scuttari   return std::make_unique<ParallelLoopSpecialization>();
348039b969bSMichele Scuttari }
349039b969bSMichele Scuttari 
350039b969bSMichele Scuttari std::unique_ptr<Pass> mlir::createForLoopSpecializationPass() {
351039b969bSMichele Scuttari   return std::make_unique<ForLoopSpecialization>();
352039b969bSMichele Scuttari }
353039b969bSMichele Scuttari 
354039b969bSMichele Scuttari std::unique_ptr<Pass> mlir::createForLoopPeelingPass() {
355039b969bSMichele Scuttari   return std::make_unique<ForLoopPeeling>();
356039b969bSMichele Scuttari }
357