xref: /llvm-project/mlir/lib/Dialect/SCF/Transforms/LoopSpecialization.cpp (revision 039b969b32b64b64123dce30dd28ec4e343d893f)
//===- LoopSpecialization.cpp - scf.parallel/scf.for specialization -------===//
24bcd08ebSStephan Herhut //
34bcd08ebSStephan Herhut // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
44bcd08ebSStephan Herhut // See https://llvm.org/LICENSE.txt for license information.
54bcd08ebSStephan Herhut // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
64bcd08ebSStephan Herhut //
74bcd08ebSStephan Herhut //===----------------------------------------------------------------------===//
84bcd08ebSStephan Herhut //
94bcd08ebSStephan Herhut // Specializes parallel loops and for loops for easier unrolling and
104bcd08ebSStephan Herhut // vectorization.
114bcd08ebSStephan Herhut //
124bcd08ebSStephan Herhut //===----------------------------------------------------------------------===//
134bcd08ebSStephan Herhut 
14*039b969bSMichele Scuttari #include "PassDetail.h"
15755dc07dSRiver Riddle #include "mlir/Dialect/Affine/Analysis/AffineStructures.h"
164bcd08ebSStephan Herhut #include "mlir/Dialect/Affine/IR/AffineOps.h"
17a54f4eaeSMogball #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
188b68da2cSAlex Zinenko #include "mlir/Dialect/SCF/IR/SCF.h"
19*039b969bSMichele Scuttari #include "mlir/Dialect/SCF/Transforms/Passes.h"
208b68da2cSAlex Zinenko #include "mlir/Dialect/SCF/Transforms/Transforms.h"
21f40475c7SAdrian Kuegel #include "mlir/Dialect/SCF/Utils/AffineCanonicalizationUtils.h"
223a41ff48SMatthias Springer #include "mlir/Dialect/Utils/StaticValueUtils.h"
234bcd08ebSStephan Herhut #include "mlir/IR/AffineExpr.h"
244bcd08ebSStephan Herhut #include "mlir/IR/BlockAndValueMapping.h"
253a41ff48SMatthias Springer #include "mlir/IR/PatternMatch.h"
263a41ff48SMatthias Springer #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
273a41ff48SMatthias Springer #include "llvm/ADT/DenseMap.h"
284bcd08ebSStephan Herhut 
294bcd08ebSStephan Herhut using namespace mlir;
304bcd08ebSStephan Herhut using scf::ForOp;
314bcd08ebSStephan Herhut using scf::ParallelOp;
324bcd08ebSStephan Herhut 
334bcd08ebSStephan Herhut /// Rewrite a parallel loop with bounds defined by an affine.min with a constant
344bcd08ebSStephan Herhut /// into 2 loops after checking if the bounds are equal to that constant. This
354bcd08ebSStephan Herhut /// is beneficial if the loop will almost always have the constant bound and
364bcd08ebSStephan Herhut /// that version can be fully unrolled and vectorized.
374bcd08ebSStephan Herhut static void specializeParallelLoopForUnrolling(ParallelOp op) {
384bcd08ebSStephan Herhut   SmallVector<int64_t, 2> constantIndices;
39c0342a2dSJacques Pienaar   constantIndices.reserve(op.getUpperBound().size());
40c0342a2dSJacques Pienaar   for (auto bound : op.getUpperBound()) {
414bcd08ebSStephan Herhut     auto minOp = bound.getDefiningOp<AffineMinOp>();
424bcd08ebSStephan Herhut     if (!minOp)
434bcd08ebSStephan Herhut       return;
444bcd08ebSStephan Herhut     int64_t minConstant = std::numeric_limits<int64_t>::max();
4504235d07SJacques Pienaar     for (AffineExpr expr : minOp.getMap().getResults()) {
464bcd08ebSStephan Herhut       if (auto constantIndex = expr.dyn_cast<AffineConstantExpr>())
474bcd08ebSStephan Herhut         minConstant = std::min(minConstant, constantIndex.getValue());
484bcd08ebSStephan Herhut     }
494bcd08ebSStephan Herhut     if (minConstant == std::numeric_limits<int64_t>::max())
504bcd08ebSStephan Herhut       return;
514bcd08ebSStephan Herhut     constantIndices.push_back(minConstant);
524bcd08ebSStephan Herhut   }
534bcd08ebSStephan Herhut 
544bcd08ebSStephan Herhut   OpBuilder b(op);
554bcd08ebSStephan Herhut   BlockAndValueMapping map;
564bcd08ebSStephan Herhut   Value cond;
57c0342a2dSJacques Pienaar   for (auto bound : llvm::zip(op.getUpperBound(), constantIndices)) {
58a54f4eaeSMogball     Value constant =
59a54f4eaeSMogball         b.create<arith::ConstantIndexOp>(op.getLoc(), std::get<1>(bound));
60a54f4eaeSMogball     Value cmp = b.create<arith::CmpIOp>(op.getLoc(), arith::CmpIPredicate::eq,
614bcd08ebSStephan Herhut                                         std::get<0>(bound), constant);
62a54f4eaeSMogball     cond = cond ? b.create<arith::AndIOp>(op.getLoc(), cond, cmp) : cmp;
634bcd08ebSStephan Herhut     map.map(std::get<0>(bound), constant);
644bcd08ebSStephan Herhut   }
654bcd08ebSStephan Herhut   auto ifOp = b.create<scf::IfOp>(op.getLoc(), cond, /*withElseRegion=*/true);
664bcd08ebSStephan Herhut   ifOp.getThenBodyBuilder().clone(*op.getOperation(), map);
674bcd08ebSStephan Herhut   ifOp.getElseBodyBuilder().clone(*op.getOperation());
684bcd08ebSStephan Herhut   op.erase();
694bcd08ebSStephan Herhut }
704bcd08ebSStephan Herhut 
714bcd08ebSStephan Herhut /// Rewrite a for loop with bounds defined by an affine.min with a constant into
724bcd08ebSStephan Herhut /// 2 loops after checking if the bounds are equal to that constant. This is
734bcd08ebSStephan Herhut /// beneficial if the loop will almost always have the constant bound and that
744bcd08ebSStephan Herhut /// version can be fully unrolled and vectorized.
754bcd08ebSStephan Herhut static void specializeForLoopForUnrolling(ForOp op) {
76c0342a2dSJacques Pienaar   auto bound = op.getUpperBound();
774bcd08ebSStephan Herhut   auto minOp = bound.getDefiningOp<AffineMinOp>();
784bcd08ebSStephan Herhut   if (!minOp)
794bcd08ebSStephan Herhut     return;
804bcd08ebSStephan Herhut   int64_t minConstant = std::numeric_limits<int64_t>::max();
8104235d07SJacques Pienaar   for (AffineExpr expr : minOp.getMap().getResults()) {
824bcd08ebSStephan Herhut     if (auto constantIndex = expr.dyn_cast<AffineConstantExpr>())
834bcd08ebSStephan Herhut       minConstant = std::min(minConstant, constantIndex.getValue());
844bcd08ebSStephan Herhut   }
854bcd08ebSStephan Herhut   if (minConstant == std::numeric_limits<int64_t>::max())
864bcd08ebSStephan Herhut     return;
874bcd08ebSStephan Herhut 
884bcd08ebSStephan Herhut   OpBuilder b(op);
894bcd08ebSStephan Herhut   BlockAndValueMapping map;
90a54f4eaeSMogball   Value constant = b.create<arith::ConstantIndexOp>(op.getLoc(), minConstant);
91a54f4eaeSMogball   Value cond = b.create<arith::CmpIOp>(op.getLoc(), arith::CmpIPredicate::eq,
92a54f4eaeSMogball                                        bound, constant);
934bcd08ebSStephan Herhut   map.map(bound, constant);
944bcd08ebSStephan Herhut   auto ifOp = b.create<scf::IfOp>(op.getLoc(), cond, /*withElseRegion=*/true);
954bcd08ebSStephan Herhut   ifOp.getThenBodyBuilder().clone(*op.getOperation(), map);
964bcd08ebSStephan Herhut   ifOp.getElseBodyBuilder().clone(*op.getOperation());
974bcd08ebSStephan Herhut   op.erase();
984bcd08ebSStephan Herhut }
994bcd08ebSStephan Herhut 
1003a41ff48SMatthias Springer /// Rewrite a for loop with bounds/step that potentially do not divide evenly
1013a41ff48SMatthias Springer /// into a for loop where the step divides the iteration space evenly, followed
1023a41ff48SMatthias Springer /// by an scf.if for the last (partial) iteration (if any).
1038e8b70aaSMatthias Springer ///
1048e8b70aaSMatthias Springer /// This function rewrites the given scf.for loop in-place and creates a new
1058e8b70aaSMatthias Springer /// scf.if operation for the last iteration. It replaces all uses of the
1068e8b70aaSMatthias Springer /// unpeeled loop with the results of the newly generated scf.if.
1078e8b70aaSMatthias Springer ///
1088e8b70aaSMatthias Springer /// The newly generated scf.if operation is returned via `ifOp`. The boundary
1098e8b70aaSMatthias Springer /// at which the loop is split (new upper bound) is returned via `splitBound`.
1108e8b70aaSMatthias Springer /// The return value indicates whether the loop was rewritten or not.
1110f3544d1SMatthias Springer static LogicalResult peelForLoop(RewriterBase &b, ForOp forOp,
1120f3544d1SMatthias Springer                                  ForOp &partialIteration, Value &splitBound) {
1133a41ff48SMatthias Springer   RewriterBase::InsertionGuard guard(b);
114c0342a2dSJacques Pienaar   auto lbInt = getConstantIntValue(forOp.getLowerBound());
115c0342a2dSJacques Pienaar   auto ubInt = getConstantIntValue(forOp.getUpperBound());
116c0342a2dSJacques Pienaar   auto stepInt = getConstantIntValue(forOp.getStep());
1173a41ff48SMatthias Springer 
1183a41ff48SMatthias Springer   // No specialization necessary if step already divides upper bound evenly.
1193a41ff48SMatthias Springer   if (lbInt && ubInt && stepInt && (*ubInt - *lbInt) % *stepInt == 0)
1203a41ff48SMatthias Springer     return failure();
1213a41ff48SMatthias Springer   // No specialization necessary if step size is 1.
1223a41ff48SMatthias Springer   if (stepInt == static_cast<int64_t>(1))
1233a41ff48SMatthias Springer     return failure();
1243a41ff48SMatthias Springer 
1253a41ff48SMatthias Springer   auto loc = forOp.getLoc();
1260c360829SMatthias Springer   AffineExpr sym0, sym1, sym2;
1270c360829SMatthias Springer   bindSymbols(b.getContext(), sym0, sym1, sym2);
1283a41ff48SMatthias Springer   // New upper bound: %ub - (%ub - %lb) mod %step
1290c360829SMatthias Springer   auto modMap = AffineMap::get(0, 3, {sym1 - ((sym1 - sym0) % sym2)});
130767974f3SMatthias Springer   b.setInsertionPoint(forOp);
131c0342a2dSJacques Pienaar   splitBound = b.createOrFold<AffineApplyOp>(loc, modMap,
132c0342a2dSJacques Pienaar                                              ValueRange{forOp.getLowerBound(),
133c0342a2dSJacques Pienaar                                                         forOp.getUpperBound(),
134c0342a2dSJacques Pienaar                                                         forOp.getStep()});
1353a41ff48SMatthias Springer 
1360f3544d1SMatthias Springer   // Create ForOp for partial iteration.
1370f3544d1SMatthias Springer   b.setInsertionPointAfter(forOp);
1380f3544d1SMatthias Springer   partialIteration = cast<ForOp>(b.clone(*forOp.getOperation()));
139c0342a2dSJacques Pienaar   partialIteration.getLowerBoundMutable().assign(splitBound);
1400f3544d1SMatthias Springer   forOp.replaceAllUsesWith(partialIteration->getResults());
141c0342a2dSJacques Pienaar   partialIteration.getInitArgsMutable().assign(forOp->getResults());
1420f3544d1SMatthias Springer 
1433a41ff48SMatthias Springer   // Set new upper loop bound.
144c0342a2dSJacques Pienaar   b.updateRootInPlace(
145c0342a2dSJacques Pienaar       forOp, [&]() { forOp.getUpperBoundMutable().assign(splitBound); });
1463a41ff48SMatthias Springer 
1473a41ff48SMatthias Springer   return success();
1483a41ff48SMatthias Springer }
1493a41ff48SMatthias Springer 
150a9cff97fSMatthias Springer template <typename OpTy, bool IsMin>
1510f3544d1SMatthias Springer static void rewriteAffineOpAfterPeeling(RewriterBase &rewriter, ForOp forOp,
1520f3544d1SMatthias Springer                                         ForOp partialIteration,
1530f3544d1SMatthias Springer                                         Value previousUb) {
1540f3544d1SMatthias Springer   Value mainIv = forOp.getInductionVar();
1550f3544d1SMatthias Springer   Value partialIv = partialIteration.getInductionVar();
156c0342a2dSJacques Pienaar   assert(forOp.getStep() == partialIteration.getStep() &&
1570f3544d1SMatthias Springer          "expected same step in main and partial loop");
158c0342a2dSJacques Pienaar   Value step = forOp.getStep();
1590f3544d1SMatthias Springer 
160a9cff97fSMatthias Springer   forOp.walk([&](OpTy affineOp) {
161c57c4f88SMatthias Springer     AffineMap map = affineOp.getAffineMap();
162c57c4f88SMatthias Springer     (void)scf::rewritePeeledMinMaxOp(rewriter, affineOp, map,
1630f3544d1SMatthias Springer                                      affineOp.operands(), IsMin, mainIv,
1640f3544d1SMatthias Springer                                      previousUb, step,
165a9cff97fSMatthias Springer                                      /*insideLoop=*/true);
166a9cff97fSMatthias Springer   });
1670f3544d1SMatthias Springer   partialIteration.walk([&](OpTy affineOp) {
168c57c4f88SMatthias Springer     AffineMap map = affineOp.getAffineMap();
169c57c4f88SMatthias Springer     (void)scf::rewritePeeledMinMaxOp(rewriter, affineOp, map,
1700f3544d1SMatthias Springer                                      affineOp.operands(), IsMin, partialIv,
1710f3544d1SMatthias Springer                                      previousUb, step, /*insideLoop=*/false);
172a9cff97fSMatthias Springer   });
1738e8b70aaSMatthias Springer }
1748e8b70aaSMatthias Springer 
1758e8b70aaSMatthias Springer LogicalResult mlir::scf::peelAndCanonicalizeForLoop(RewriterBase &rewriter,
176bc194a5bSMatthias Springer                                                     ForOp forOp,
1770f3544d1SMatthias Springer                                                     ForOp &partialIteration) {
178c0342a2dSJacques Pienaar   Value previousUb = forOp.getUpperBound();
1798e8b70aaSMatthias Springer   Value splitBound;
1800f3544d1SMatthias Springer   if (failed(peelForLoop(rewriter, forOp, partialIteration, splitBound)))
1818e8b70aaSMatthias Springer     return failure();
1828e8b70aaSMatthias Springer 
183a9cff97fSMatthias Springer   // Rewrite affine.min and affine.max ops.
184a9cff97fSMatthias Springer   rewriteAffineOpAfterPeeling<AffineMinOp, /*IsMin=*/true>(
1850f3544d1SMatthias Springer       rewriter, forOp, partialIteration, previousUb);
186a9cff97fSMatthias Springer   rewriteAffineOpAfterPeeling<AffineMaxOp, /*IsMin=*/false>(
1870f3544d1SMatthias Springer       rewriter, forOp, partialIteration, previousUb);
1888e8b70aaSMatthias Springer 
1898e8b70aaSMatthias Springer   return success();
1908e8b70aaSMatthias Springer }
1918e8b70aaSMatthias Springer 
1923a41ff48SMatthias Springer static constexpr char kPeeledLoopLabel[] = "__peeled_loop__";
193bc194a5bSMatthias Springer static constexpr char kPartialIterationLabel[] = "__partial_iteration__";
1943a41ff48SMatthias Springer 
1953a41ff48SMatthias Springer namespace {
1963a41ff48SMatthias Springer struct ForLoopPeelingPattern : public OpRewritePattern<ForOp> {
197bc194a5bSMatthias Springer   ForLoopPeelingPattern(MLIRContext *ctx, bool skipPartial)
198bc194a5bSMatthias Springer       : OpRewritePattern<ForOp>(ctx), skipPartial(skipPartial) {}
1993a41ff48SMatthias Springer 
2003a41ff48SMatthias Springer   LogicalResult matchAndRewrite(ForOp forOp,
2013a41ff48SMatthias Springer                                 PatternRewriter &rewriter) const override {
202bc194a5bSMatthias Springer     // Do not peel already peeled loops.
2033a41ff48SMatthias Springer     if (forOp->hasAttr(kPeeledLoopLabel))
2043a41ff48SMatthias Springer       return failure();
205bc194a5bSMatthias Springer     if (skipPartial) {
2060f3544d1SMatthias Springer       // No peeling of loops inside the partial iteration of another peeled
2070f3544d1SMatthias Springer       // loop.
208bc194a5bSMatthias Springer       Operation *op = forOp.getOperation();
2090f3544d1SMatthias Springer       while ((op = op->getParentOfType<scf::ForOp>())) {
210bc194a5bSMatthias Springer         if (op->hasAttr(kPartialIterationLabel))
211bc194a5bSMatthias Springer           return failure();
212bc194a5bSMatthias Springer       }
213bc194a5bSMatthias Springer     }
214bc194a5bSMatthias Springer     // Apply loop peeling.
2150f3544d1SMatthias Springer     scf::ForOp partialIteration;
2160f3544d1SMatthias Springer     if (failed(peelAndCanonicalizeForLoop(rewriter, forOp, partialIteration)))
2173a41ff48SMatthias Springer       return failure();
2183a41ff48SMatthias Springer     // Apply label, so that the same loop is not rewritten a second time.
2190f3544d1SMatthias Springer     partialIteration->setAttr(kPeeledLoopLabel, rewriter.getUnitAttr());
2203a41ff48SMatthias Springer     rewriter.updateRootInPlace(forOp, [&]() {
2213a41ff48SMatthias Springer       forOp->setAttr(kPeeledLoopLabel, rewriter.getUnitAttr());
2223a41ff48SMatthias Springer     });
2230f3544d1SMatthias Springer     partialIteration->setAttr(kPartialIterationLabel, rewriter.getUnitAttr());
2243a41ff48SMatthias Springer     return success();
2253a41ff48SMatthias Springer   }
226bc194a5bSMatthias Springer 
227bc194a5bSMatthias Springer   /// If set to true, loops inside partial iterations of another peeled loop
228bc194a5bSMatthias Springer   /// are not peeled. This reduces the size of the generated code. Partial
229bc194a5bSMatthias Springer   /// iterations are not usually performance critical.
230bc194a5bSMatthias Springer   /// Note: Takes into account the entire chain of parent operations, not just
231bc194a5bSMatthias Springer   /// the direct parent.
232bc194a5bSMatthias Springer   bool skipPartial;
2333a41ff48SMatthias Springer };
2343a41ff48SMatthias Springer } // namespace
2353a41ff48SMatthias Springer 
2364bcd08ebSStephan Herhut namespace {
237*039b969bSMichele Scuttari struct ParallelLoopSpecialization
238*039b969bSMichele Scuttari     : public SCFParallelLoopSpecializationBase<ParallelLoopSpecialization> {
23941574554SRiver Riddle   void runOnOperation() override {
24054998986SStella Laurenzo     getOperation()->walk(
2414bcd08ebSStephan Herhut         [](ParallelOp op) { specializeParallelLoopForUnrolling(op); });
2424bcd08ebSStephan Herhut   }
2434bcd08ebSStephan Herhut };
2444bcd08ebSStephan Herhut 
245*039b969bSMichele Scuttari struct ForLoopSpecialization
246*039b969bSMichele Scuttari     : public SCFForLoopSpecializationBase<ForLoopSpecialization> {
24741574554SRiver Riddle   void runOnOperation() override {
24854998986SStella Laurenzo     getOperation()->walk([](ForOp op) { specializeForLoopForUnrolling(op); });
2494bcd08ebSStephan Herhut   }
2504bcd08ebSStephan Herhut };
2513a41ff48SMatthias Springer 
252*039b969bSMichele Scuttari struct ForLoopPeeling : public SCFForLoopPeelingBase<ForLoopPeeling> {
25341574554SRiver Riddle   void runOnOperation() override {
25454998986SStella Laurenzo     auto *parentOp = getOperation();
25554998986SStella Laurenzo     MLIRContext *ctx = parentOp->getContext();
2563a41ff48SMatthias Springer     RewritePatternSet patterns(ctx);
257bc194a5bSMatthias Springer     patterns.add<ForLoopPeelingPattern>(ctx, skipPartial);
25854998986SStella Laurenzo     (void)applyPatternsAndFoldGreedily(parentOp, std::move(patterns));
2593a41ff48SMatthias Springer 
260bc194a5bSMatthias Springer     // Drop the markers.
26154998986SStella Laurenzo     parentOp->walk([](Operation *op) {
262bc194a5bSMatthias Springer       op->removeAttr(kPeeledLoopLabel);
263bc194a5bSMatthias Springer       op->removeAttr(kPartialIterationLabel);
264bc194a5bSMatthias Springer     });
2653a41ff48SMatthias Springer   }
2663a41ff48SMatthias Springer };
2674bcd08ebSStephan Herhut } // namespace
268*039b969bSMichele Scuttari 
269*039b969bSMichele Scuttari std::unique_ptr<Pass> mlir::createParallelLoopSpecializationPass() {
270*039b969bSMichele Scuttari   return std::make_unique<ParallelLoopSpecialization>();
271*039b969bSMichele Scuttari }
272*039b969bSMichele Scuttari 
273*039b969bSMichele Scuttari std::unique_ptr<Pass> mlir::createForLoopSpecializationPass() {
274*039b969bSMichele Scuttari   return std::make_unique<ForLoopSpecialization>();
275*039b969bSMichele Scuttari }
276*039b969bSMichele Scuttari 
277*039b969bSMichele Scuttari std::unique_ptr<Pass> mlir::createForLoopPeelingPass() {
278*039b969bSMichele Scuttari   return std::make_unique<ForLoopPeeling>();
279*039b969bSMichele Scuttari }
280