xref: /llvm-project/mlir/lib/Dialect/SCF/Transforms/ParallelLoopTiling.cpp (revision 10056c821a56a19cef732129e4e0c5883ae1ee49)
12cbf0b3dSlorenzo chelini //===- ParallelLoopTiling.cpp - Tiles scf.parallel ------------------------===//
2c25b20c0SAlex Zinenko //
3c25b20c0SAlex Zinenko // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4c25b20c0SAlex Zinenko // See https://llvm.org/LICENSE.txt for license information.
5c25b20c0SAlex Zinenko // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6c25b20c0SAlex Zinenko //
7c25b20c0SAlex Zinenko //===----------------------------------------------------------------------===//
8c25b20c0SAlex Zinenko //
9c25b20c0SAlex Zinenko // This file implements loop tiling on parallel loops.
10c25b20c0SAlex Zinenko //
11c25b20c0SAlex Zinenko //===----------------------------------------------------------------------===//
12c25b20c0SAlex Zinenko 
1367d0d7acSMichele Scuttari #include "mlir/Dialect/SCF/Transforms/Passes.h"
1467d0d7acSMichele Scuttari 
15c25b20c0SAlex Zinenko #include "mlir/Dialect/Affine/IR/AffineOps.h"
16abc362a1SJakub Kuderski #include "mlir/Dialect/Arith/IR/Arith.h"
178b68da2cSAlex Zinenko #include "mlir/Dialect/SCF/IR/SCF.h"
188b68da2cSAlex Zinenko #include "mlir/Dialect/SCF/Transforms/Transforms.h"
19f40475c7SAdrian Kuegel #include "mlir/Dialect/SCF/Utils/Utils.h"
20c25b20c0SAlex Zinenko 
2167d0d7acSMichele Scuttari namespace mlir {
2267d0d7acSMichele Scuttari #define GEN_PASS_DEF_SCFPARALLELLOOPTILING
2367d0d7acSMichele Scuttari #include "mlir/Dialect/SCF/Transforms/Passes.h.inc"
2467d0d7acSMichele Scuttari } // namespace mlir
2567d0d7acSMichele Scuttari 
26c25b20c0SAlex Zinenko using namespace mlir;
27c25b20c0SAlex Zinenko using namespace mlir::scf;
28c25b20c0SAlex Zinenko 
29c25b20c0SAlex Zinenko /// Tile a parallel loop of the form
3060f443bbSAlex Zinenko ///   scf.parallel (%i0, %i1) = (%arg0, %arg1) to (%arg2, %arg3)
31c25b20c0SAlex Zinenko ///                                            step (%arg4, %arg5)
32c25b20c0SAlex Zinenko ///
33c25b20c0SAlex Zinenko /// into
3460f443bbSAlex Zinenko ///   scf.parallel (%i0, %i1) = (%arg0, %arg1) to (%arg2, %arg3)
35c25b20c0SAlex Zinenko ///                                            step (%arg4*tileSize[0],
36c25b20c0SAlex Zinenko ///                                                  %arg5*tileSize[1])
37cd730816STobias Gysi ///     scf.parallel (%j0, %j1) = (0, 0) to (min(%arg4*tileSize[0], %arg2-%i0)
38cd730816STobias Gysi ///                                          min(%arg5*tileSize[1], %arg3-%i1))
39c25b20c0SAlex Zinenko ///                                      step (%arg4, %arg5)
401e60678cSStephan Herhut ///
412d45e332Stashuang.zk /// or, when no-min-max-bounds is true, into
422d45e332Stashuang.zk ///   scf.parallel (%i0, %i1) = (%arg0, %arg1) to (%arg2, %arg3)
432d45e332Stashuang.zk ///                                            step (%arg4*tileSize[0],
442d45e332Stashuang.zk ///                                                  %arg5*tileSize[1])
452d45e332Stashuang.zk ///     scf.parallel (%j0, %j1) = (0, 0) to (%arg4*tileSize[0],
462d45e332Stashuang.zk ///                                          %arg5*tileSize[1])
472d45e332Stashuang.zk ///                                      step (%arg4, %arg5)
482d45e332Stashuang.zk ///        %inbound = (%j0 * %arg4 + %i0 < %arg2) &&
492d45e332Stashuang.zk ///                   (%j1 * %arg5 + %i1 < %arg3)
502d45e332Stashuang.zk ///        scf.if (%inbound)
512d45e332Stashuang.zk ///          ....
522d45e332Stashuang.zk ///
531e60678cSStephan Herhut /// where the uses of %i0 and %i1 in the loop body are replaced by
541e60678cSStephan Herhut /// %i0 + j0 and %i1 + %j1.
552cbf0b3dSlorenzo chelini ///
56c25b20c0SAlex Zinenko /// The old loop is replaced with the new one.
5709c18a66SAlexander Belyaev std::pair<ParallelOp, ParallelOp>
tileParallelLoop(ParallelOp op,ArrayRef<int64_t> tileSizes,bool noMinMaxBounds)582d45e332Stashuang.zk mlir::scf::tileParallelLoop(ParallelOp op, ArrayRef<int64_t> tileSizes,
592d45e332Stashuang.zk                             bool noMinMaxBounds) {
60c25b20c0SAlex Zinenko   OpBuilder b(op);
61a54f4eaeSMogball   auto zero = b.create<arith::ConstantIndexOp>(op.getLoc(), 0);
62c25b20c0SAlex Zinenko   SmallVector<Value, 2> tileSizeConstants;
63c0342a2dSJacques Pienaar   tileSizeConstants.reserve(op.getUpperBound().size());
64c0342a2dSJacques Pienaar   for (size_t i = 0, end = op.getUpperBound().size(); i != end; ++i) {
65c25b20c0SAlex Zinenko     if (i < tileSizes.size())
66c25b20c0SAlex Zinenko       tileSizeConstants.push_back(
67a54f4eaeSMogball           b.create<arith::ConstantIndexOp>(op.getLoc(), tileSizes[i]));
68c25b20c0SAlex Zinenko     else
69c25b20c0SAlex Zinenko       // Just pick 1 for the remaining dimensions.
70a54f4eaeSMogball       tileSizeConstants.push_back(
71a54f4eaeSMogball           b.create<arith::ConstantIndexOp>(op.getLoc(), 1));
72c25b20c0SAlex Zinenko   }
73c25b20c0SAlex Zinenko 
74c25b20c0SAlex Zinenko   // Create the outer loop with adjusted steps.
75c25b20c0SAlex Zinenko   SmallVector<Value, 2> newSteps;
76c0342a2dSJacques Pienaar   newSteps.reserve(op.getStep().size());
77c0342a2dSJacques Pienaar   for (auto step : llvm::zip(op.getStep(), tileSizeConstants)) {
78a54f4eaeSMogball     newSteps.push_back(b.create<arith::MulIOp>(op.getLoc(), std::get<0>(step),
79a54f4eaeSMogball                                                std::get<1>(step)));
80c25b20c0SAlex Zinenko   }
81c0342a2dSJacques Pienaar   auto outerLoop = b.create<ParallelOp>(op.getLoc(), op.getLowerBound(),
82c0342a2dSJacques Pienaar                                         op.getUpperBound(), newSteps);
83c25b20c0SAlex Zinenko   b.setInsertionPointToStart(outerLoop.getBody());
84c25b20c0SAlex Zinenko 
85c25b20c0SAlex Zinenko   // Compute min(size, dim - offset) to avoid out-of-bounds accesses.
86c25b20c0SAlex Zinenko   auto minMap = AffineMap::get(
87c25b20c0SAlex Zinenko       /*dimCount=*/3, /*symbolCount=*/0,
88c25b20c0SAlex Zinenko       {getAffineDimExpr(/*position=*/0, b.getContext()),
89c25b20c0SAlex Zinenko        getAffineDimExpr(/*position=*/1, b.getContext()) -
90c25b20c0SAlex Zinenko            getAffineDimExpr(/*position=*/2, b.getContext())},
91c25b20c0SAlex Zinenko       b.getContext());
92c25b20c0SAlex Zinenko 
93c25b20c0SAlex Zinenko   // Create the inner loop with adjusted bounds.
94c25b20c0SAlex Zinenko   SmallVector<Value, 2> newBounds;
95c0342a2dSJacques Pienaar   newBounds.reserve(op.getUpperBound().size());
962d45e332Stashuang.zk   bool needInboundCheck = false;
979fa59e76SBenjamin Kramer   for (auto [lowerBound, upperBound, newStep, iv, step, tileSizeConstant] :
98c0342a2dSJacques Pienaar        llvm::zip(outerLoop.getLowerBound(), outerLoop.getUpperBound(),
99c0342a2dSJacques Pienaar                  outerLoop.getStep(), outerLoop.getInductionVars(),
100c0342a2dSJacques Pienaar                  op.getStep(), tileSizeConstants)) {
101cd730816STobias Gysi     // Collect the statically known loop bounds
102cd730816STobias Gysi     auto lowerBoundConstant =
103a54f4eaeSMogball         dyn_cast_or_null<arith::ConstantIndexOp>(lowerBound.getDefiningOp());
104cd730816STobias Gysi     auto upperBoundConstant =
105a54f4eaeSMogball         dyn_cast_or_null<arith::ConstantIndexOp>(upperBound.getDefiningOp());
106a54f4eaeSMogball     auto stepConstant =
107a54f4eaeSMogball         dyn_cast_or_null<arith::ConstantIndexOp>(step.getDefiningOp());
108cd730816STobias Gysi     auto tileSize =
109a54f4eaeSMogball         cast<arith::ConstantIndexOp>(tileSizeConstant.getDefiningOp()).value();
110cd730816STobias Gysi     // If the loop bounds and the loop step are constant and if the number of
111cd730816STobias Gysi     // loop iterations is an integer multiple of the tile size, we use a static
112cd730816STobias Gysi     // bound for the inner loop.
113cd730816STobias Gysi     if (lowerBoundConstant && upperBoundConstant && stepConstant) {
114a54f4eaeSMogball       auto numIterations = llvm::divideCeil(upperBoundConstant.value() -
115a54f4eaeSMogball                                                 lowerBoundConstant.value(),
116a54f4eaeSMogball                                             stepConstant.value());
117cd730816STobias Gysi       if (numIterations % tileSize == 0) {
118cd730816STobias Gysi         newBounds.push_back(newStep);
119cd730816STobias Gysi         continue;
120cd730816STobias Gysi       }
121cd730816STobias Gysi     }
1222d45e332Stashuang.zk 
1232d45e332Stashuang.zk     // For InboundCheck mode, just use the variable outer step
1242d45e332Stashuang.zk     if (noMinMaxBounds) {
1252d45e332Stashuang.zk       newBounds.push_back(newStep);
1262d45e332Stashuang.zk       needInboundCheck = true;
1272d45e332Stashuang.zk       continue;
1282d45e332Stashuang.zk     }
1292d45e332Stashuang.zk 
130cd730816STobias Gysi     // Otherwise, we dynamically compute the bound for
131cd730816STobias Gysi     // each iteration of the outer loop.
132cd730816STobias Gysi     newBounds.push_back(
1334c48f016SMatthias Springer         b.create<affine::AffineMinOp>(op.getLoc(), b.getIndexType(), minMap,
134cd730816STobias Gysi                                       ValueRange{newStep, upperBound, iv}));
135c25b20c0SAlex Zinenko   }
136c25b20c0SAlex Zinenko   auto innerLoop = b.create<ParallelOp>(
137c25b20c0SAlex Zinenko       op.getLoc(), SmallVector<Value, 2>(newBounds.size(), zero), newBounds,
138c0342a2dSJacques Pienaar       op.getStep());
139c25b20c0SAlex Zinenko 
1402d45e332Stashuang.zk   if (noMinMaxBounds && needInboundCheck) {
1411e60678cSStephan Herhut     b.setInsertionPointToStart(innerLoop.getBody());
1422d45e332Stashuang.zk     // Insert in-bound check
1432d45e332Stashuang.zk     Value inbound =
144a54f4eaeSMogball         b.create<arith::ConstantIntOp>(op.getLoc(), 1, b.getIntegerType(1));
1459fa59e76SBenjamin Kramer     for (auto [outerUpperBound, outerIV, innerIV, innerStep] :
146c0342a2dSJacques Pienaar          llvm::zip(outerLoop.getUpperBound(), outerLoop.getInductionVars(),
147c0342a2dSJacques Pienaar                    innerLoop.getInductionVars(), innerLoop.getStep())) {
1482d45e332Stashuang.zk       // %in_bound = %in_bound &&
1492d45e332Stashuang.zk       //             (%inner_iv * %inner_step + %outer_iv < %outer_upper_bound)
150a54f4eaeSMogball       Value index = b.create<arith::AddIOp>(
151a54f4eaeSMogball           op.getLoc(), b.create<arith::MulIOp>(op.getLoc(), innerIV, innerStep),
1522d45e332Stashuang.zk           outerIV);
153a54f4eaeSMogball       Value dimInbound = b.create<arith::CmpIOp>(
154a54f4eaeSMogball           op.getLoc(), arith::CmpIPredicate::ult, index, outerUpperBound);
155a54f4eaeSMogball       inbound = b.create<arith::AndIOp>(op.getLoc(), inbound, dimInbound);
1562d45e332Stashuang.zk     }
1572d45e332Stashuang.zk     auto ifInbound = b.create<IfOp>(op.getLoc(),
1582d45e332Stashuang.zk                                     /*resultTypes*/ ArrayRef<Type>{}, inbound,
1592d45e332Stashuang.zk                                     /*hasElseRegion*/ false);
160c0342a2dSJacques Pienaar     ifInbound.getThenRegion().takeBody(op.getRegion());
161c0342a2dSJacques Pienaar     Block &thenBlock = ifInbound.getThenRegion().front();
162*10056c82SMatthias Springer     // Replace the scf.reduce terminator with an scf.yield terminator.
163*10056c82SMatthias Springer     Operation *reduceOp = thenBlock.getTerminator();
164*10056c82SMatthias Springer     b.setInsertionPointToEnd(&thenBlock);
165*10056c82SMatthias Springer     b.create<scf::YieldOp>(reduceOp->getLoc());
166*10056c82SMatthias Springer     reduceOp->erase();
1672d45e332Stashuang.zk     b.setInsertionPointToStart(innerLoop.getBody());
168e4853be2SMehdi Amini     for (const auto &ivs : llvm::enumerate(llvm::zip(
169e4853be2SMehdi Amini              innerLoop.getInductionVars(), outerLoop.getInductionVars()))) {
170a54f4eaeSMogball       auto newIndex = b.create<arith::AddIOp>(
171a54f4eaeSMogball           op.getLoc(), std::get<0>(ivs.value()), std::get<1>(ivs.value()));
1722d45e332Stashuang.zk       thenBlock.getArgument(ivs.index())
1732d45e332Stashuang.zk           .replaceAllUsesExcept(newIndex, newIndex);
1742d45e332Stashuang.zk     }
1755b569ed2SJeff Niu     thenBlock.eraseArguments(0, thenBlock.getNumArguments());
1762d45e332Stashuang.zk   } else {
177c0342a2dSJacques Pienaar     innerLoop.getRegion().takeBody(op.getRegion());
1782d45e332Stashuang.zk     b.setInsertionPointToStart(innerLoop.getBody());
1792d45e332Stashuang.zk     for (auto ivs : llvm::zip(innerLoop.getInductionVars(),
1802d45e332Stashuang.zk                               outerLoop.getInductionVars())) {
1812d45e332Stashuang.zk       Value innerIndex = std::get<0>(ivs);
182a54f4eaeSMogball       auto newIndex = b.create<arith::AddIOp>(op.getLoc(), std::get<0>(ivs),
183a54f4eaeSMogball                                               std::get<1>(ivs));
1842d45e332Stashuang.zk       innerIndex.replaceAllUsesExcept(newIndex, newIndex);
1852d45e332Stashuang.zk     }
1861e60678cSStephan Herhut   }
1871e60678cSStephan Herhut 
188c25b20c0SAlex Zinenko   op.erase();
18909c18a66SAlexander Belyaev   return std::make_pair(outerLoop, innerLoop);
190c25b20c0SAlex Zinenko }
191c25b20c0SAlex Zinenko 
192c25b20c0SAlex Zinenko namespace {
193039b969bSMichele Scuttari struct ParallelLoopTiling
19467d0d7acSMichele Scuttari     : public impl::SCFParallelLoopTilingBase<ParallelLoopTiling> {
195039b969bSMichele Scuttari   ParallelLoopTiling() = default;
ParallelLoopTiling__anond4f611d80111::ParallelLoopTiling196039b969bSMichele Scuttari   explicit ParallelLoopTiling(ArrayRef<int64_t> tileSizes,
197039b969bSMichele Scuttari                               bool noMinMaxBounds = false) {
198039b969bSMichele Scuttari     this->tileSizes = tileSizes;
199039b969bSMichele Scuttari     this->noMinMaxBounds = noMinMaxBounds;
200039b969bSMichele Scuttari   }
201c25b20c0SAlex Zinenko 
runOnOperation__anond4f611d80111::ParallelLoopTiling20241574554SRiver Riddle   void runOnOperation() override {
203e38c8bdcSJustin Fargnoli     for (auto tileSize : tileSizes)
204e38c8bdcSJustin Fargnoli       if (tileSize == 0) {
205e38c8bdcSJustin Fargnoli         mlir::emitError(mlir::UnknownLoc::get(&Pass::getContext()),
206e38c8bdcSJustin Fargnoli                         "tile size cannot be 0");
207e38c8bdcSJustin Fargnoli         return signalPassFailure();
208e38c8bdcSJustin Fargnoli       }
20954998986SStella Laurenzo     auto *parentOp = getOperation();
2106484567fSFrederik Gossen     SmallVector<ParallelOp, 2> innermostPloops;
21154998986SStella Laurenzo     getInnermostParallelLoops(parentOp, innermostPloops);
2126484567fSFrederik Gossen     for (ParallelOp ploop : innermostPloops) {
213cd730816STobias Gysi       // FIXME: Add reduction support.
2146484567fSFrederik Gossen       if (ploop.getNumReductions() == 0)
2152d45e332Stashuang.zk         tileParallelLoop(ploop, tileSizes, noMinMaxBounds);
216c25b20c0SAlex Zinenko     }
217c25b20c0SAlex Zinenko   }
218c25b20c0SAlex Zinenko };
219c25b20c0SAlex Zinenko } // namespace
220039b969bSMichele Scuttari 
221039b969bSMichele Scuttari std::unique_ptr<Pass>
createParallelLoopTilingPass(ArrayRef<int64_t> tileSizes,bool noMinMaxBounds)222039b969bSMichele Scuttari mlir::createParallelLoopTilingPass(ArrayRef<int64_t> tileSizes,
223039b969bSMichele Scuttari                                    bool noMinMaxBounds) {
224039b969bSMichele Scuttari   return std::make_unique<ParallelLoopTiling>(tileSizes, noMinMaxBounds);
225039b969bSMichele Scuttari }
226