xref: /llvm-project/mlir/lib/Dialect/SCF/Transforms/LoopSpecialization.cpp (revision 09dfc5713d7e2342bea4c8447d1ed76c85eb8225)
//===- LoopSpecialization.cpp - scf.parallel/scf.for specialization -------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Specializes parallel loops and for loops for easier unrolling and
// vectorization.
//
//===----------------------------------------------------------------------===//
134bcd08ebSStephan Herhut 
1467d0d7acSMichele Scuttari #include "mlir/Dialect/SCF/Transforms/Passes.h"
1567d0d7acSMichele Scuttari 
16755dc07dSRiver Riddle #include "mlir/Dialect/Affine/Analysis/AffineStructures.h"
174bcd08ebSStephan Herhut #include "mlir/Dialect/Affine/IR/AffineOps.h"
18abc362a1SJakub Kuderski #include "mlir/Dialect/Arith/IR/Arith.h"
198b68da2cSAlex Zinenko #include "mlir/Dialect/SCF/IR/SCF.h"
208b68da2cSAlex Zinenko #include "mlir/Dialect/SCF/Transforms/Transforms.h"
21f40475c7SAdrian Kuegel #include "mlir/Dialect/SCF/Utils/AffineCanonicalizationUtils.h"
223a41ff48SMatthias Springer #include "mlir/Dialect/Utils/StaticValueUtils.h"
234bcd08ebSStephan Herhut #include "mlir/IR/AffineExpr.h"
244d67b278SJeff Niu #include "mlir/IR/IRMapping.h"
253a41ff48SMatthias Springer #include "mlir/IR/PatternMatch.h"
263a41ff48SMatthias Springer #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
273a41ff48SMatthias Springer #include "llvm/ADT/DenseMap.h"
284bcd08ebSStephan Herhut 
2967d0d7acSMichele Scuttari namespace mlir {
3067d0d7acSMichele Scuttari #define GEN_PASS_DEF_SCFFORLOOPPEELING
3167d0d7acSMichele Scuttari #define GEN_PASS_DEF_SCFFORLOOPSPECIALIZATION
3267d0d7acSMichele Scuttari #define GEN_PASS_DEF_SCFPARALLELLOOPSPECIALIZATION
3367d0d7acSMichele Scuttari #include "mlir/Dialect/SCF/Transforms/Passes.h.inc"
3467d0d7acSMichele Scuttari } // namespace mlir
3567d0d7acSMichele Scuttari 
364bcd08ebSStephan Herhut using namespace mlir;
374c48f016SMatthias Springer using namespace mlir::affine;
384bcd08ebSStephan Herhut using scf::ForOp;
394bcd08ebSStephan Herhut using scf::ParallelOp;
404bcd08ebSStephan Herhut 
414bcd08ebSStephan Herhut /// Rewrite a parallel loop with bounds defined by an affine.min with a constant
424bcd08ebSStephan Herhut /// into 2 loops after checking if the bounds are equal to that constant. This
434bcd08ebSStephan Herhut /// is beneficial if the loop will almost always have the constant bound and
444bcd08ebSStephan Herhut /// that version can be fully unrolled and vectorized.
454bcd08ebSStephan Herhut static void specializeParallelLoopForUnrolling(ParallelOp op) {
464bcd08ebSStephan Herhut   SmallVector<int64_t, 2> constantIndices;
47c0342a2dSJacques Pienaar   constantIndices.reserve(op.getUpperBound().size());
48c0342a2dSJacques Pienaar   for (auto bound : op.getUpperBound()) {
494bcd08ebSStephan Herhut     auto minOp = bound.getDefiningOp<AffineMinOp>();
504bcd08ebSStephan Herhut     if (!minOp)
514bcd08ebSStephan Herhut       return;
524bcd08ebSStephan Herhut     int64_t minConstant = std::numeric_limits<int64_t>::max();
5304235d07SJacques Pienaar     for (AffineExpr expr : minOp.getMap().getResults()) {
541609f1c2Slong.chen       if (auto constantIndex = dyn_cast<AffineConstantExpr>(expr))
554bcd08ebSStephan Herhut         minConstant = std::min(minConstant, constantIndex.getValue());
564bcd08ebSStephan Herhut     }
574bcd08ebSStephan Herhut     if (minConstant == std::numeric_limits<int64_t>::max())
584bcd08ebSStephan Herhut       return;
594bcd08ebSStephan Herhut     constantIndices.push_back(minConstant);
604bcd08ebSStephan Herhut   }
614bcd08ebSStephan Herhut 
624bcd08ebSStephan Herhut   OpBuilder b(op);
634d67b278SJeff Niu   IRMapping map;
644bcd08ebSStephan Herhut   Value cond;
65c0342a2dSJacques Pienaar   for (auto bound : llvm::zip(op.getUpperBound(), constantIndices)) {
66a54f4eaeSMogball     Value constant =
67a54f4eaeSMogball         b.create<arith::ConstantIndexOp>(op.getLoc(), std::get<1>(bound));
68a54f4eaeSMogball     Value cmp = b.create<arith::CmpIOp>(op.getLoc(), arith::CmpIPredicate::eq,
694bcd08ebSStephan Herhut                                         std::get<0>(bound), constant);
70a54f4eaeSMogball     cond = cond ? b.create<arith::AndIOp>(op.getLoc(), cond, cmp) : cmp;
714bcd08ebSStephan Herhut     map.map(std::get<0>(bound), constant);
724bcd08ebSStephan Herhut   }
734bcd08ebSStephan Herhut   auto ifOp = b.create<scf::IfOp>(op.getLoc(), cond, /*withElseRegion=*/true);
744bcd08ebSStephan Herhut   ifOp.getThenBodyBuilder().clone(*op.getOperation(), map);
754bcd08ebSStephan Herhut   ifOp.getElseBodyBuilder().clone(*op.getOperation());
764bcd08ebSStephan Herhut   op.erase();
774bcd08ebSStephan Herhut }
784bcd08ebSStephan Herhut 
794bcd08ebSStephan Herhut /// Rewrite a for loop with bounds defined by an affine.min with a constant into
804bcd08ebSStephan Herhut /// 2 loops after checking if the bounds are equal to that constant. This is
814bcd08ebSStephan Herhut /// beneficial if the loop will almost always have the constant bound and that
824bcd08ebSStephan Herhut /// version can be fully unrolled and vectorized.
834bcd08ebSStephan Herhut static void specializeForLoopForUnrolling(ForOp op) {
84c0342a2dSJacques Pienaar   auto bound = op.getUpperBound();
854bcd08ebSStephan Herhut   auto minOp = bound.getDefiningOp<AffineMinOp>();
864bcd08ebSStephan Herhut   if (!minOp)
874bcd08ebSStephan Herhut     return;
884bcd08ebSStephan Herhut   int64_t minConstant = std::numeric_limits<int64_t>::max();
8904235d07SJacques Pienaar   for (AffineExpr expr : minOp.getMap().getResults()) {
901609f1c2Slong.chen     if (auto constantIndex = dyn_cast<AffineConstantExpr>(expr))
914bcd08ebSStephan Herhut       minConstant = std::min(minConstant, constantIndex.getValue());
924bcd08ebSStephan Herhut   }
934bcd08ebSStephan Herhut   if (minConstant == std::numeric_limits<int64_t>::max())
944bcd08ebSStephan Herhut     return;
954bcd08ebSStephan Herhut 
964bcd08ebSStephan Herhut   OpBuilder b(op);
974d67b278SJeff Niu   IRMapping map;
98a54f4eaeSMogball   Value constant = b.create<arith::ConstantIndexOp>(op.getLoc(), minConstant);
99a54f4eaeSMogball   Value cond = b.create<arith::CmpIOp>(op.getLoc(), arith::CmpIPredicate::eq,
100a54f4eaeSMogball                                        bound, constant);
1014bcd08ebSStephan Herhut   map.map(bound, constant);
1024bcd08ebSStephan Herhut   auto ifOp = b.create<scf::IfOp>(op.getLoc(), cond, /*withElseRegion=*/true);
1034bcd08ebSStephan Herhut   ifOp.getThenBodyBuilder().clone(*op.getOperation(), map);
1044bcd08ebSStephan Herhut   ifOp.getElseBodyBuilder().clone(*op.getOperation());
1054bcd08ebSStephan Herhut   op.erase();
1064bcd08ebSStephan Herhut }
1074bcd08ebSStephan Herhut 
1083a41ff48SMatthias Springer /// Rewrite a for loop with bounds/step that potentially do not divide evenly
1093a41ff48SMatthias Springer /// into a for loop where the step divides the iteration space evenly, followed
1103a41ff48SMatthias Springer /// by an scf.if for the last (partial) iteration (if any).
1118e8b70aaSMatthias Springer ///
1128e8b70aaSMatthias Springer /// This function rewrites the given scf.for loop in-place and creates a new
1138e8b70aaSMatthias Springer /// scf.if operation for the last iteration. It replaces all uses of the
1148e8b70aaSMatthias Springer /// unpeeled loop with the results of the newly generated scf.if.
1158e8b70aaSMatthias Springer ///
1168e8b70aaSMatthias Springer /// The newly generated scf.if operation is returned via `ifOp`. The boundary
1178e8b70aaSMatthias Springer /// at which the loop is split (new upper bound) is returned via `splitBound`.
1188e8b70aaSMatthias Springer /// The return value indicates whether the loop was rewritten or not.
1190f3544d1SMatthias Springer static LogicalResult peelForLoop(RewriterBase &b, ForOp forOp,
1200f3544d1SMatthias Springer                                  ForOp &partialIteration, Value &splitBound) {
1213a41ff48SMatthias Springer   RewriterBase::InsertionGuard guard(b);
122c0342a2dSJacques Pienaar   auto lbInt = getConstantIntValue(forOp.getLowerBound());
123c0342a2dSJacques Pienaar   auto ubInt = getConstantIntValue(forOp.getUpperBound());
124c0342a2dSJacques Pienaar   auto stepInt = getConstantIntValue(forOp.getStep());
1253a41ff48SMatthias Springer 
126f6f1ab9dSFelix Schneider   // No specialization necessary if step size is 1. Also bail out in case of an
127f6f1ab9dSFelix Schneider   // invalid zero or negative step which might have happened during folding.
128f6f1ab9dSFelix Schneider   if (stepInt && *stepInt <= 1)
1293a41ff48SMatthias Springer     return failure();
1303a41ff48SMatthias Springer 
13196901f1bSMatthias Springer   // No specialization necessary if step already divides upper bound evenly.
13296901f1bSMatthias Springer   // Fast path: lb, ub and step are constants.
13396901f1bSMatthias Springer   if (lbInt && ubInt && stepInt && (*ubInt - *lbInt) % *stepInt == 0)
13496901f1bSMatthias Springer     return failure();
13596901f1bSMatthias Springer   // Slow path: Examine the ops that define lb, ub and step.
1360c360829SMatthias Springer   AffineExpr sym0, sym1, sym2;
1370c360829SMatthias Springer   bindSymbols(b.getContext(), sym0, sym1, sym2);
13896901f1bSMatthias Springer   SmallVector<Value> operands{forOp.getLowerBound(), forOp.getUpperBound(),
13996901f1bSMatthias Springer                               forOp.getStep()};
14096901f1bSMatthias Springer   AffineMap map = AffineMap::get(0, 3, {(sym1 - sym0) % sym2});
14196901f1bSMatthias Springer   affine::fullyComposeAffineMapAndOperands(&map, &operands);
14296901f1bSMatthias Springer   if (auto constExpr = dyn_cast<AffineConstantExpr>(map.getResult(0)))
14396901f1bSMatthias Springer     if (constExpr.getValue() == 0)
14496901f1bSMatthias Springer       return failure();
14596901f1bSMatthias Springer 
1463a41ff48SMatthias Springer   // New upper bound: %ub - (%ub - %lb) mod %step
1470c360829SMatthias Springer   auto modMap = AffineMap::get(0, 3, {sym1 - ((sym1 - sym0) % sym2)});
148767974f3SMatthias Springer   b.setInsertionPoint(forOp);
14996901f1bSMatthias Springer   auto loc = forOp.getLoc();
150c0342a2dSJacques Pienaar   splitBound = b.createOrFold<AffineApplyOp>(loc, modMap,
151c0342a2dSJacques Pienaar                                              ValueRange{forOp.getLowerBound(),
152c0342a2dSJacques Pienaar                                                         forOp.getUpperBound(),
153c0342a2dSJacques Pienaar                                                         forOp.getStep()});
1543a41ff48SMatthias Springer 
1550f3544d1SMatthias Springer   // Create ForOp for partial iteration.
1560f3544d1SMatthias Springer   b.setInsertionPointAfter(forOp);
1570f3544d1SMatthias Springer   partialIteration = cast<ForOp>(b.clone(*forOp.getOperation()));
158c0342a2dSJacques Pienaar   partialIteration.getLowerBoundMutable().assign(splitBound);
15961f37758SMatthias Springer   b.replaceAllUsesWith(forOp.getResults(), partialIteration->getResults());
160c0342a2dSJacques Pienaar   partialIteration.getInitArgsMutable().assign(forOp->getResults());
1610f3544d1SMatthias Springer 
1623a41ff48SMatthias Springer   // Set new upper loop bound.
1635fcf907bSMatthias Springer   b.modifyOpInPlace(forOp,
1645fcf907bSMatthias Springer                     [&]() { forOp.getUpperBoundMutable().assign(splitBound); });
1653a41ff48SMatthias Springer 
1663a41ff48SMatthias Springer   return success();
1673a41ff48SMatthias Springer }
1683a41ff48SMatthias Springer 
1690f3544d1SMatthias Springer static void rewriteAffineOpAfterPeeling(RewriterBase &rewriter, ForOp forOp,
1700f3544d1SMatthias Springer                                         ForOp partialIteration,
1710f3544d1SMatthias Springer                                         Value previousUb) {
1720f3544d1SMatthias Springer   Value mainIv = forOp.getInductionVar();
1730f3544d1SMatthias Springer   Value partialIv = partialIteration.getInductionVar();
174c0342a2dSJacques Pienaar   assert(forOp.getStep() == partialIteration.getStep() &&
1750f3544d1SMatthias Springer          "expected same step in main and partial loop");
176c0342a2dSJacques Pienaar   Value step = forOp.getStep();
1770f3544d1SMatthias Springer 
1783a5811a3SMatthias Springer   forOp.walk([&](Operation *affineOp) {
1793a5811a3SMatthias Springer     if (!isa<AffineMinOp, AffineMaxOp>(affineOp))
1803a5811a3SMatthias Springer       return WalkResult::advance();
1813a5811a3SMatthias Springer     (void)scf::rewritePeeledMinMaxOp(rewriter, affineOp, mainIv, previousUb,
1823a5811a3SMatthias Springer                                      step,
183a9cff97fSMatthias Springer                                      /*insideLoop=*/true);
1843a5811a3SMatthias Springer     return WalkResult::advance();
185a9cff97fSMatthias Springer   });
1863a5811a3SMatthias Springer   partialIteration.walk([&](Operation *affineOp) {
1873a5811a3SMatthias Springer     if (!isa<AffineMinOp, AffineMaxOp>(affineOp))
1883a5811a3SMatthias Springer       return WalkResult::advance();
1893a5811a3SMatthias Springer     (void)scf::rewritePeeledMinMaxOp(rewriter, affineOp, partialIv, previousUb,
1903a5811a3SMatthias Springer                                      step, /*insideLoop=*/false);
1913a5811a3SMatthias Springer     return WalkResult::advance();
192a9cff97fSMatthias Springer   });
1938e8b70aaSMatthias Springer }
1948e8b70aaSMatthias Springer 
195bb2ae985SNicolas Vasilache LogicalResult mlir::scf::peelForLoopAndSimplifyBounds(RewriterBase &rewriter,
196bc194a5bSMatthias Springer                                                       ForOp forOp,
1970f3544d1SMatthias Springer                                                       ForOp &partialIteration) {
198c0342a2dSJacques Pienaar   Value previousUb = forOp.getUpperBound();
1998e8b70aaSMatthias Springer   Value splitBound;
2000f3544d1SMatthias Springer   if (failed(peelForLoop(rewriter, forOp, partialIteration, splitBound)))
2018e8b70aaSMatthias Springer     return failure();
2028e8b70aaSMatthias Springer 
203a9cff97fSMatthias Springer   // Rewrite affine.min and affine.max ops.
2043a5811a3SMatthias Springer   rewriteAffineOpAfterPeeling(rewriter, forOp, partialIteration, previousUb);
2058e8b70aaSMatthias Springer 
2068e8b70aaSMatthias Springer   return success();
2078e8b70aaSMatthias Springer }
2088e8b70aaSMatthias Springer 
209a9c417c2SHugo Trachino /// Rewrites the original scf::ForOp as two scf::ForOp Ops, the first
210a9c417c2SHugo Trachino /// scf::ForOp corresponds to the first iteration of the loop which can be
211a9c417c2SHugo Trachino /// canonicalized away in the following optimizations. The second loop Op
212a9c417c2SHugo Trachino /// contains the remaining iterations, with a lower bound updated as the
213a9c417c2SHugo Trachino /// original lower bound plus the step (i.e. skips the first iteration).
214bd6a2452SVivian LogicalResult mlir::scf::peelForLoopFirstIteration(RewriterBase &b, ForOp forOp,
215bd6a2452SVivian                                                    ForOp &firstIteration) {
216bd6a2452SVivian   RewriterBase::InsertionGuard guard(b);
217bd6a2452SVivian   auto lbInt = getConstantIntValue(forOp.getLowerBound());
218bd6a2452SVivian   auto ubInt = getConstantIntValue(forOp.getUpperBound());
219bd6a2452SVivian   auto stepInt = getConstantIntValue(forOp.getStep());
220bd6a2452SVivian 
221bd6a2452SVivian   // Peeling is not needed if there is one or less iteration.
222a9d1feadSVivian   if (lbInt && ubInt && stepInt && ceil(float(*ubInt - *lbInt) / *stepInt) <= 1)
223bd6a2452SVivian     return failure();
224bd6a2452SVivian 
225bd6a2452SVivian   AffineExpr lbSymbol, stepSymbol;
226bd6a2452SVivian   bindSymbols(b.getContext(), lbSymbol, stepSymbol);
227bd6a2452SVivian 
228bd6a2452SVivian   // New lower bound for main loop: %lb + %step
229bd6a2452SVivian   auto ubMap = AffineMap::get(0, 2, {lbSymbol + stepSymbol});
230bd6a2452SVivian   b.setInsertionPoint(forOp);
231bd6a2452SVivian   auto loc = forOp.getLoc();
232bd6a2452SVivian   Value splitBound = b.createOrFold<AffineApplyOp>(
233bd6a2452SVivian       loc, ubMap, ValueRange{forOp.getLowerBound(), forOp.getStep()});
234bd6a2452SVivian 
235bd6a2452SVivian   // Peel the first iteration.
236bd6a2452SVivian   IRMapping map;
237bd6a2452SVivian   map.map(forOp.getUpperBound(), splitBound);
238bd6a2452SVivian   firstIteration = cast<ForOp>(b.clone(*forOp.getOperation(), map));
239bd6a2452SVivian 
240bd6a2452SVivian   // Update main loop with new lower bound.
2415fcf907bSMatthias Springer   b.modifyOpInPlace(forOp, [&]() {
242bd6a2452SVivian     forOp.getInitArgsMutable().assign(firstIteration->getResults());
243bd6a2452SVivian     forOp.getLowerBoundMutable().assign(splitBound);
244bd6a2452SVivian   });
245bd6a2452SVivian 
246bd6a2452SVivian   return success();
247bd6a2452SVivian }
248bd6a2452SVivian 
/// Name of the unit attribute attached to every loop that took part in
/// peeling (main loop and peeled loop), so that it is not peeled again.
static constexpr char kPeeledLoopLabel[] = "__peeled_loop__";
/// Name of the unit attribute attached to the loop produced by peeling, so
/// that loops nested inside of it can be skipped.
static constexpr char kPartialIterationLabel[] = "__partial_iteration__";
2513a41ff48SMatthias Springer 
2523a41ff48SMatthias Springer namespace {
2533a41ff48SMatthias Springer struct ForLoopPeelingPattern : public OpRewritePattern<ForOp> {
254bd6a2452SVivian   ForLoopPeelingPattern(MLIRContext *ctx, bool peelFront, bool skipPartial)
255bd6a2452SVivian       : OpRewritePattern<ForOp>(ctx), peelFront(peelFront),
256bd6a2452SVivian         skipPartial(skipPartial) {}
2573a41ff48SMatthias Springer 
2583a41ff48SMatthias Springer   LogicalResult matchAndRewrite(ForOp forOp,
2593a41ff48SMatthias Springer                                 PatternRewriter &rewriter) const override {
260bc194a5bSMatthias Springer     // Do not peel already peeled loops.
2613a41ff48SMatthias Springer     if (forOp->hasAttr(kPeeledLoopLabel))
2623a41ff48SMatthias Springer       return failure();
263bd6a2452SVivian 
264bd6a2452SVivian     scf::ForOp partialIteration;
265bd6a2452SVivian     // The case for peeling the first iteration of the loop.
266bd6a2452SVivian     if (peelFront) {
267bd6a2452SVivian       if (failed(
268bd6a2452SVivian               peelForLoopFirstIteration(rewriter, forOp, partialIteration))) {
269bd6a2452SVivian         return failure();
270bd6a2452SVivian       }
271bd6a2452SVivian     } else {
272bc194a5bSMatthias Springer       if (skipPartial) {
2730f3544d1SMatthias Springer         // No peeling of loops inside the partial iteration of another peeled
2740f3544d1SMatthias Springer         // loop.
275bc194a5bSMatthias Springer         Operation *op = forOp.getOperation();
2760f3544d1SMatthias Springer         while ((op = op->getParentOfType<scf::ForOp>())) {
277bc194a5bSMatthias Springer           if (op->hasAttr(kPartialIterationLabel))
278bc194a5bSMatthias Springer             return failure();
279bc194a5bSMatthias Springer         }
280bc194a5bSMatthias Springer       }
281bc194a5bSMatthias Springer       // Apply loop peeling.
282bd6a2452SVivian       if (failed(
283bd6a2452SVivian               peelForLoopAndSimplifyBounds(rewriter, forOp, partialIteration)))
2843a41ff48SMatthias Springer         return failure();
285bd6a2452SVivian     }
286bd6a2452SVivian 
2873a41ff48SMatthias Springer     // Apply label, so that the same loop is not rewritten a second time.
2885fcf907bSMatthias Springer     rewriter.modifyOpInPlace(partialIteration, [&]() {
2890f3544d1SMatthias Springer       partialIteration->setAttr(kPeeledLoopLabel, rewriter.getUnitAttr());
29061f37758SMatthias Springer       partialIteration->setAttr(kPartialIterationLabel, rewriter.getUnitAttr());
29161f37758SMatthias Springer     });
2925fcf907bSMatthias Springer     rewriter.modifyOpInPlace(forOp, [&]() {
2933a41ff48SMatthias Springer       forOp->setAttr(kPeeledLoopLabel, rewriter.getUnitAttr());
2943a41ff48SMatthias Springer     });
2953a41ff48SMatthias Springer     return success();
2963a41ff48SMatthias Springer   }
297bc194a5bSMatthias Springer 
298bd6a2452SVivian   // If set to true, the first iteration of the loop will be peeled. Otherwise,
299bd6a2452SVivian   // the unevenly divisible loop will be peeled at the end.
300bd6a2452SVivian   bool peelFront;
301bd6a2452SVivian 
302bc194a5bSMatthias Springer   /// If set to true, loops inside partial iterations of another peeled loop
303bc194a5bSMatthias Springer   /// are not peeled. This reduces the size of the generated code. Partial
304bc194a5bSMatthias Springer   /// iterations are not usually performance critical.
305bc194a5bSMatthias Springer   /// Note: Takes into account the entire chain of parent operations, not just
306bc194a5bSMatthias Springer   /// the direct parent.
307bc194a5bSMatthias Springer   bool skipPartial;
3083a41ff48SMatthias Springer };
3093a41ff48SMatthias Springer } // namespace
3103a41ff48SMatthias Springer 
3114bcd08ebSStephan Herhut namespace {
312039b969bSMichele Scuttari struct ParallelLoopSpecialization
31367d0d7acSMichele Scuttari     : public impl::SCFParallelLoopSpecializationBase<
31467d0d7acSMichele Scuttari           ParallelLoopSpecialization> {
31541574554SRiver Riddle   void runOnOperation() override {
31654998986SStella Laurenzo     getOperation()->walk(
3174bcd08ebSStephan Herhut         [](ParallelOp op) { specializeParallelLoopForUnrolling(op); });
3184bcd08ebSStephan Herhut   }
3194bcd08ebSStephan Herhut };
3204bcd08ebSStephan Herhut 
321039b969bSMichele Scuttari struct ForLoopSpecialization
32267d0d7acSMichele Scuttari     : public impl::SCFForLoopSpecializationBase<ForLoopSpecialization> {
32341574554SRiver Riddle   void runOnOperation() override {
32454998986SStella Laurenzo     getOperation()->walk([](ForOp op) { specializeForLoopForUnrolling(op); });
3254bcd08ebSStephan Herhut   }
3264bcd08ebSStephan Herhut };
3273a41ff48SMatthias Springer 
32867d0d7acSMichele Scuttari struct ForLoopPeeling : public impl::SCFForLoopPeelingBase<ForLoopPeeling> {
32941574554SRiver Riddle   void runOnOperation() override {
33054998986SStella Laurenzo     auto *parentOp = getOperation();
33154998986SStella Laurenzo     MLIRContext *ctx = parentOp->getContext();
3323a41ff48SMatthias Springer     RewritePatternSet patterns(ctx);
333bd6a2452SVivian     patterns.add<ForLoopPeelingPattern>(ctx, peelFront, skipPartial);
334*09dfc571SJacques Pienaar     (void)applyPatternsGreedily(parentOp, std::move(patterns));
3353a41ff48SMatthias Springer 
336bc194a5bSMatthias Springer     // Drop the markers.
33754998986SStella Laurenzo     parentOp->walk([](Operation *op) {
338bc194a5bSMatthias Springer       op->removeAttr(kPeeledLoopLabel);
339bc194a5bSMatthias Springer       op->removeAttr(kPartialIterationLabel);
340bc194a5bSMatthias Springer     });
3413a41ff48SMatthias Springer   }
3423a41ff48SMatthias Springer };
3434bcd08ebSStephan Herhut } // namespace
344039b969bSMichele Scuttari 
345039b969bSMichele Scuttari std::unique_ptr<Pass> mlir::createParallelLoopSpecializationPass() {
346039b969bSMichele Scuttari   return std::make_unique<ParallelLoopSpecialization>();
347039b969bSMichele Scuttari }
348039b969bSMichele Scuttari 
349039b969bSMichele Scuttari std::unique_ptr<Pass> mlir::createForLoopSpecializationPass() {
350039b969bSMichele Scuttari   return std::make_unique<ForLoopSpecialization>();
351039b969bSMichele Scuttari }
352039b969bSMichele Scuttari 
353039b969bSMichele Scuttari std::unique_ptr<Pass> mlir::createForLoopPeelingPass() {
354039b969bSMichele Scuttari   return std::make_unique<ForLoopPeeling>();
355039b969bSMichele Scuttari }
356