12cbf0b3dSlorenzo chelini //===- ParallelLoopTiling.cpp - Tiles scf.parallel ------------------------===//
2c25b20c0SAlex Zinenko //
3c25b20c0SAlex Zinenko // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4c25b20c0SAlex Zinenko // See https://llvm.org/LICENSE.txt for license information.
5c25b20c0SAlex Zinenko // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6c25b20c0SAlex Zinenko //
7c25b20c0SAlex Zinenko //===----------------------------------------------------------------------===//
8c25b20c0SAlex Zinenko //
9c25b20c0SAlex Zinenko // This file implements loop tiling on parallel loops.
10c25b20c0SAlex Zinenko //
11c25b20c0SAlex Zinenko //===----------------------------------------------------------------------===//
12c25b20c0SAlex Zinenko
1367d0d7acSMichele Scuttari #include "mlir/Dialect/SCF/Transforms/Passes.h"
1467d0d7acSMichele Scuttari
15c25b20c0SAlex Zinenko #include "mlir/Dialect/Affine/IR/AffineOps.h"
16abc362a1SJakub Kuderski #include "mlir/Dialect/Arith/IR/Arith.h"
178b68da2cSAlex Zinenko #include "mlir/Dialect/SCF/IR/SCF.h"
188b68da2cSAlex Zinenko #include "mlir/Dialect/SCF/Transforms/Transforms.h"
19f40475c7SAdrian Kuegel #include "mlir/Dialect/SCF/Utils/Utils.h"
20c25b20c0SAlex Zinenko
2167d0d7acSMichele Scuttari namespace mlir {
2267d0d7acSMichele Scuttari #define GEN_PASS_DEF_SCFPARALLELLOOPTILING
2367d0d7acSMichele Scuttari #include "mlir/Dialect/SCF/Transforms/Passes.h.inc"
2467d0d7acSMichele Scuttari } // namespace mlir
2567d0d7acSMichele Scuttari
26c25b20c0SAlex Zinenko using namespace mlir;
27c25b20c0SAlex Zinenko using namespace mlir::scf;
28c25b20c0SAlex Zinenko
29c25b20c0SAlex Zinenko /// Tile a parallel loop of the form
3060f443bbSAlex Zinenko /// scf.parallel (%i0, %i1) = (%arg0, %arg1) to (%arg2, %arg3)
31c25b20c0SAlex Zinenko /// step (%arg4, %arg5)
32c25b20c0SAlex Zinenko ///
33c25b20c0SAlex Zinenko /// into
3460f443bbSAlex Zinenko /// scf.parallel (%i0, %i1) = (%arg0, %arg1) to (%arg2, %arg3)
35c25b20c0SAlex Zinenko /// step (%arg4*tileSize[0],
36c25b20c0SAlex Zinenko /// %arg5*tileSize[1])
37cd730816STobias Gysi /// scf.parallel (%j0, %j1) = (0, 0) to (min(%arg4*tileSize[0], %arg2-%i0)
38cd730816STobias Gysi /// min(%arg5*tileSize[1], %arg3-%i1))
39c25b20c0SAlex Zinenko /// step (%arg4, %arg5)
401e60678cSStephan Herhut ///
412d45e332Stashuang.zk /// or, when no-min-max-bounds is true, into
422d45e332Stashuang.zk /// scf.parallel (%i0, %i1) = (%arg0, %arg1) to (%arg2, %arg3)
432d45e332Stashuang.zk /// step (%arg4*tileSize[0],
442d45e332Stashuang.zk /// %arg5*tileSize[1])
452d45e332Stashuang.zk /// scf.parallel (%j0, %j1) = (0, 0) to (%arg4*tileSize[0],
462d45e332Stashuang.zk /// %arg5*tileSize[1])
472d45e332Stashuang.zk /// step (%arg4, %arg5)
482d45e332Stashuang.zk /// %inbound = (%j0 * %arg4 + %i0 < %arg2) &&
492d45e332Stashuang.zk /// (%j1 * %arg5 + %i1 < %arg3)
502d45e332Stashuang.zk /// scf.if (%inbound)
512d45e332Stashuang.zk /// ....
522d45e332Stashuang.zk ///
531e60678cSStephan Herhut /// where the uses of %i0 and %i1 in the loop body are replaced by
541e60678cSStephan Herhut /// %i0 + j0 and %i1 + %j1.
552cbf0b3dSlorenzo chelini ///
56c25b20c0SAlex Zinenko /// The old loop is replaced with the new one.
5709c18a66SAlexander Belyaev std::pair<ParallelOp, ParallelOp>
tileParallelLoop(ParallelOp op,ArrayRef<int64_t> tileSizes,bool noMinMaxBounds)582d45e332Stashuang.zk mlir::scf::tileParallelLoop(ParallelOp op, ArrayRef<int64_t> tileSizes,
592d45e332Stashuang.zk bool noMinMaxBounds) {
60c25b20c0SAlex Zinenko OpBuilder b(op);
61a54f4eaeSMogball auto zero = b.create<arith::ConstantIndexOp>(op.getLoc(), 0);
62c25b20c0SAlex Zinenko SmallVector<Value, 2> tileSizeConstants;
63c0342a2dSJacques Pienaar tileSizeConstants.reserve(op.getUpperBound().size());
64c0342a2dSJacques Pienaar for (size_t i = 0, end = op.getUpperBound().size(); i != end; ++i) {
65c25b20c0SAlex Zinenko if (i < tileSizes.size())
66c25b20c0SAlex Zinenko tileSizeConstants.push_back(
67a54f4eaeSMogball b.create<arith::ConstantIndexOp>(op.getLoc(), tileSizes[i]));
68c25b20c0SAlex Zinenko else
69c25b20c0SAlex Zinenko // Just pick 1 for the remaining dimensions.
70a54f4eaeSMogball tileSizeConstants.push_back(
71a54f4eaeSMogball b.create<arith::ConstantIndexOp>(op.getLoc(), 1));
72c25b20c0SAlex Zinenko }
73c25b20c0SAlex Zinenko
74c25b20c0SAlex Zinenko // Create the outer loop with adjusted steps.
75c25b20c0SAlex Zinenko SmallVector<Value, 2> newSteps;
76c0342a2dSJacques Pienaar newSteps.reserve(op.getStep().size());
77c0342a2dSJacques Pienaar for (auto step : llvm::zip(op.getStep(), tileSizeConstants)) {
78a54f4eaeSMogball newSteps.push_back(b.create<arith::MulIOp>(op.getLoc(), std::get<0>(step),
79a54f4eaeSMogball std::get<1>(step)));
80c25b20c0SAlex Zinenko }
81c0342a2dSJacques Pienaar auto outerLoop = b.create<ParallelOp>(op.getLoc(), op.getLowerBound(),
82c0342a2dSJacques Pienaar op.getUpperBound(), newSteps);
83c25b20c0SAlex Zinenko b.setInsertionPointToStart(outerLoop.getBody());
84c25b20c0SAlex Zinenko
85c25b20c0SAlex Zinenko // Compute min(size, dim - offset) to avoid out-of-bounds accesses.
86c25b20c0SAlex Zinenko auto minMap = AffineMap::get(
87c25b20c0SAlex Zinenko /*dimCount=*/3, /*symbolCount=*/0,
88c25b20c0SAlex Zinenko {getAffineDimExpr(/*position=*/0, b.getContext()),
89c25b20c0SAlex Zinenko getAffineDimExpr(/*position=*/1, b.getContext()) -
90c25b20c0SAlex Zinenko getAffineDimExpr(/*position=*/2, b.getContext())},
91c25b20c0SAlex Zinenko b.getContext());
92c25b20c0SAlex Zinenko
93c25b20c0SAlex Zinenko // Create the inner loop with adjusted bounds.
94c25b20c0SAlex Zinenko SmallVector<Value, 2> newBounds;
95c0342a2dSJacques Pienaar newBounds.reserve(op.getUpperBound().size());
962d45e332Stashuang.zk bool needInboundCheck = false;
979fa59e76SBenjamin Kramer for (auto [lowerBound, upperBound, newStep, iv, step, tileSizeConstant] :
98c0342a2dSJacques Pienaar llvm::zip(outerLoop.getLowerBound(), outerLoop.getUpperBound(),
99c0342a2dSJacques Pienaar outerLoop.getStep(), outerLoop.getInductionVars(),
100c0342a2dSJacques Pienaar op.getStep(), tileSizeConstants)) {
101cd730816STobias Gysi // Collect the statically known loop bounds
102cd730816STobias Gysi auto lowerBoundConstant =
103a54f4eaeSMogball dyn_cast_or_null<arith::ConstantIndexOp>(lowerBound.getDefiningOp());
104cd730816STobias Gysi auto upperBoundConstant =
105a54f4eaeSMogball dyn_cast_or_null<arith::ConstantIndexOp>(upperBound.getDefiningOp());
106a54f4eaeSMogball auto stepConstant =
107a54f4eaeSMogball dyn_cast_or_null<arith::ConstantIndexOp>(step.getDefiningOp());
108cd730816STobias Gysi auto tileSize =
109a54f4eaeSMogball cast<arith::ConstantIndexOp>(tileSizeConstant.getDefiningOp()).value();
110cd730816STobias Gysi // If the loop bounds and the loop step are constant and if the number of
111cd730816STobias Gysi // loop iterations is an integer multiple of the tile size, we use a static
112cd730816STobias Gysi // bound for the inner loop.
113cd730816STobias Gysi if (lowerBoundConstant && upperBoundConstant && stepConstant) {
114a54f4eaeSMogball auto numIterations = llvm::divideCeil(upperBoundConstant.value() -
115a54f4eaeSMogball lowerBoundConstant.value(),
116a54f4eaeSMogball stepConstant.value());
117cd730816STobias Gysi if (numIterations % tileSize == 0) {
118cd730816STobias Gysi newBounds.push_back(newStep);
119cd730816STobias Gysi continue;
120cd730816STobias Gysi }
121cd730816STobias Gysi }
1222d45e332Stashuang.zk
1232d45e332Stashuang.zk // For InboundCheck mode, just use the variable outer step
1242d45e332Stashuang.zk if (noMinMaxBounds) {
1252d45e332Stashuang.zk newBounds.push_back(newStep);
1262d45e332Stashuang.zk needInboundCheck = true;
1272d45e332Stashuang.zk continue;
1282d45e332Stashuang.zk }
1292d45e332Stashuang.zk
130cd730816STobias Gysi // Otherwise, we dynamically compute the bound for
131cd730816STobias Gysi // each iteration of the outer loop.
132cd730816STobias Gysi newBounds.push_back(
1334c48f016SMatthias Springer b.create<affine::AffineMinOp>(op.getLoc(), b.getIndexType(), minMap,
134cd730816STobias Gysi ValueRange{newStep, upperBound, iv}));
135c25b20c0SAlex Zinenko }
136c25b20c0SAlex Zinenko auto innerLoop = b.create<ParallelOp>(
137c25b20c0SAlex Zinenko op.getLoc(), SmallVector<Value, 2>(newBounds.size(), zero), newBounds,
138c0342a2dSJacques Pienaar op.getStep());
139c25b20c0SAlex Zinenko
1402d45e332Stashuang.zk if (noMinMaxBounds && needInboundCheck) {
1411e60678cSStephan Herhut b.setInsertionPointToStart(innerLoop.getBody());
1422d45e332Stashuang.zk // Insert in-bound check
1432d45e332Stashuang.zk Value inbound =
144a54f4eaeSMogball b.create<arith::ConstantIntOp>(op.getLoc(), 1, b.getIntegerType(1));
1459fa59e76SBenjamin Kramer for (auto [outerUpperBound, outerIV, innerIV, innerStep] :
146c0342a2dSJacques Pienaar llvm::zip(outerLoop.getUpperBound(), outerLoop.getInductionVars(),
147c0342a2dSJacques Pienaar innerLoop.getInductionVars(), innerLoop.getStep())) {
1482d45e332Stashuang.zk // %in_bound = %in_bound &&
1492d45e332Stashuang.zk // (%inner_iv * %inner_step + %outer_iv < %outer_upper_bound)
150a54f4eaeSMogball Value index = b.create<arith::AddIOp>(
151a54f4eaeSMogball op.getLoc(), b.create<arith::MulIOp>(op.getLoc(), innerIV, innerStep),
1522d45e332Stashuang.zk outerIV);
153a54f4eaeSMogball Value dimInbound = b.create<arith::CmpIOp>(
154a54f4eaeSMogball op.getLoc(), arith::CmpIPredicate::ult, index, outerUpperBound);
155a54f4eaeSMogball inbound = b.create<arith::AndIOp>(op.getLoc(), inbound, dimInbound);
1562d45e332Stashuang.zk }
1572d45e332Stashuang.zk auto ifInbound = b.create<IfOp>(op.getLoc(),
1582d45e332Stashuang.zk /*resultTypes*/ ArrayRef<Type>{}, inbound,
1592d45e332Stashuang.zk /*hasElseRegion*/ false);
160c0342a2dSJacques Pienaar ifInbound.getThenRegion().takeBody(op.getRegion());
161c0342a2dSJacques Pienaar Block &thenBlock = ifInbound.getThenRegion().front();
162*10056c82SMatthias Springer // Replace the scf.reduce terminator with an scf.yield terminator.
163*10056c82SMatthias Springer Operation *reduceOp = thenBlock.getTerminator();
164*10056c82SMatthias Springer b.setInsertionPointToEnd(&thenBlock);
165*10056c82SMatthias Springer b.create<scf::YieldOp>(reduceOp->getLoc());
166*10056c82SMatthias Springer reduceOp->erase();
1672d45e332Stashuang.zk b.setInsertionPointToStart(innerLoop.getBody());
168e4853be2SMehdi Amini for (const auto &ivs : llvm::enumerate(llvm::zip(
169e4853be2SMehdi Amini innerLoop.getInductionVars(), outerLoop.getInductionVars()))) {
170a54f4eaeSMogball auto newIndex = b.create<arith::AddIOp>(
171a54f4eaeSMogball op.getLoc(), std::get<0>(ivs.value()), std::get<1>(ivs.value()));
1722d45e332Stashuang.zk thenBlock.getArgument(ivs.index())
1732d45e332Stashuang.zk .replaceAllUsesExcept(newIndex, newIndex);
1742d45e332Stashuang.zk }
1755b569ed2SJeff Niu thenBlock.eraseArguments(0, thenBlock.getNumArguments());
1762d45e332Stashuang.zk } else {
177c0342a2dSJacques Pienaar innerLoop.getRegion().takeBody(op.getRegion());
1782d45e332Stashuang.zk b.setInsertionPointToStart(innerLoop.getBody());
1792d45e332Stashuang.zk for (auto ivs : llvm::zip(innerLoop.getInductionVars(),
1802d45e332Stashuang.zk outerLoop.getInductionVars())) {
1812d45e332Stashuang.zk Value innerIndex = std::get<0>(ivs);
182a54f4eaeSMogball auto newIndex = b.create<arith::AddIOp>(op.getLoc(), std::get<0>(ivs),
183a54f4eaeSMogball std::get<1>(ivs));
1842d45e332Stashuang.zk innerIndex.replaceAllUsesExcept(newIndex, newIndex);
1852d45e332Stashuang.zk }
1861e60678cSStephan Herhut }
1871e60678cSStephan Herhut
188c25b20c0SAlex Zinenko op.erase();
18909c18a66SAlexander Belyaev return std::make_pair(outerLoop, innerLoop);
190c25b20c0SAlex Zinenko }
191c25b20c0SAlex Zinenko
192c25b20c0SAlex Zinenko namespace {
193039b969bSMichele Scuttari struct ParallelLoopTiling
19467d0d7acSMichele Scuttari : public impl::SCFParallelLoopTilingBase<ParallelLoopTiling> {
195039b969bSMichele Scuttari ParallelLoopTiling() = default;
ParallelLoopTiling__anond4f611d80111::ParallelLoopTiling196039b969bSMichele Scuttari explicit ParallelLoopTiling(ArrayRef<int64_t> tileSizes,
197039b969bSMichele Scuttari bool noMinMaxBounds = false) {
198039b969bSMichele Scuttari this->tileSizes = tileSizes;
199039b969bSMichele Scuttari this->noMinMaxBounds = noMinMaxBounds;
200039b969bSMichele Scuttari }
201c25b20c0SAlex Zinenko
runOnOperation__anond4f611d80111::ParallelLoopTiling20241574554SRiver Riddle void runOnOperation() override {
203e38c8bdcSJustin Fargnoli for (auto tileSize : tileSizes)
204e38c8bdcSJustin Fargnoli if (tileSize == 0) {
205e38c8bdcSJustin Fargnoli mlir::emitError(mlir::UnknownLoc::get(&Pass::getContext()),
206e38c8bdcSJustin Fargnoli "tile size cannot be 0");
207e38c8bdcSJustin Fargnoli return signalPassFailure();
208e38c8bdcSJustin Fargnoli }
20954998986SStella Laurenzo auto *parentOp = getOperation();
2106484567fSFrederik Gossen SmallVector<ParallelOp, 2> innermostPloops;
21154998986SStella Laurenzo getInnermostParallelLoops(parentOp, innermostPloops);
2126484567fSFrederik Gossen for (ParallelOp ploop : innermostPloops) {
213cd730816STobias Gysi // FIXME: Add reduction support.
2146484567fSFrederik Gossen if (ploop.getNumReductions() == 0)
2152d45e332Stashuang.zk tileParallelLoop(ploop, tileSizes, noMinMaxBounds);
216c25b20c0SAlex Zinenko }
217c25b20c0SAlex Zinenko }
218c25b20c0SAlex Zinenko };
219c25b20c0SAlex Zinenko } // namespace
220039b969bSMichele Scuttari
221039b969bSMichele Scuttari std::unique_ptr<Pass>
createParallelLoopTilingPass(ArrayRef<int64_t> tileSizes,bool noMinMaxBounds)222039b969bSMichele Scuttari mlir::createParallelLoopTilingPass(ArrayRef<int64_t> tileSizes,
223039b969bSMichele Scuttari bool noMinMaxBounds) {
224039b969bSMichele Scuttari return std::make_unique<ParallelLoopTiling>(tileSizes, noMinMaxBounds);
225039b969bSMichele Scuttari }
226