xref: /llvm-project/mlir/lib/Dialect/SCF/Transforms/ParallelLoopTiling.cpp (revision 10056c821a56a19cef732129e4e0c5883ae1ee49)
1 //===- ParallelLoopTiling.cpp - Tiles scf.parallel ------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements loop tiling on parallel loops.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Dialect/SCF/Transforms/Passes.h"
14 
15 #include "mlir/Dialect/Affine/IR/AffineOps.h"
16 #include "mlir/Dialect/Arith/IR/Arith.h"
17 #include "mlir/Dialect/SCF/IR/SCF.h"
18 #include "mlir/Dialect/SCF/Transforms/Transforms.h"
19 #include "mlir/Dialect/SCF/Utils/Utils.h"
20 
21 namespace mlir {
22 #define GEN_PASS_DEF_SCFPARALLELLOOPTILING
23 #include "mlir/Dialect/SCF/Transforms/Passes.h.inc"
24 } // namespace mlir
25 
26 using namespace mlir;
27 using namespace mlir::scf;
28 
29 /// Tile a parallel loop of the form
30 ///   scf.parallel (%i0, %i1) = (%arg0, %arg1) to (%arg2, %arg3)
31 ///                                            step (%arg4, %arg5)
32 ///
33 /// into
34 ///   scf.parallel (%i0, %i1) = (%arg0, %arg1) to (%arg2, %arg3)
35 ///                                            step (%arg4*tileSize[0],
36 ///                                                  %arg5*tileSize[1])
37 ///     scf.parallel (%j0, %j1) = (0, 0) to (min(%arg4*tileSize[0], %arg2-%i0)
38 ///                                          min(%arg5*tileSize[1], %arg3-%i1))
39 ///                                      step (%arg4, %arg5)
40 ///
41 /// or, when no-min-max-bounds is true, into
42 ///   scf.parallel (%i0, %i1) = (%arg0, %arg1) to (%arg2, %arg3)
43 ///                                            step (%arg4*tileSize[0],
44 ///                                                  %arg5*tileSize[1])
45 ///     scf.parallel (%j0, %j1) = (0, 0) to (%arg4*tileSize[0],
46 ///                                          %arg5*tileSize[1])
47 ///                                      step (%arg4, %arg5)
48 ///        %inbound = (%j0 * %arg4 + %i0 < %arg2) &&
49 ///                   (%j1 * %arg5 + %i1 < %arg3)
50 ///        scf.if (%inbound)
51 ///          ....
52 ///
53 /// where the uses of %i0 and %i1 in the loop body are replaced by
54 /// %i0 + j0 and %i1 + %j1.
55 ///
56 /// The old loop is replaced with the new one.
57 std::pair<ParallelOp, ParallelOp>
tileParallelLoop(ParallelOp op,ArrayRef<int64_t> tileSizes,bool noMinMaxBounds)58 mlir::scf::tileParallelLoop(ParallelOp op, ArrayRef<int64_t> tileSizes,
59                             bool noMinMaxBounds) {
60   OpBuilder b(op);
61   auto zero = b.create<arith::ConstantIndexOp>(op.getLoc(), 0);
62   SmallVector<Value, 2> tileSizeConstants;
63   tileSizeConstants.reserve(op.getUpperBound().size());
64   for (size_t i = 0, end = op.getUpperBound().size(); i != end; ++i) {
65     if (i < tileSizes.size())
66       tileSizeConstants.push_back(
67           b.create<arith::ConstantIndexOp>(op.getLoc(), tileSizes[i]));
68     else
69       // Just pick 1 for the remaining dimensions.
70       tileSizeConstants.push_back(
71           b.create<arith::ConstantIndexOp>(op.getLoc(), 1));
72   }
73 
74   // Create the outer loop with adjusted steps.
75   SmallVector<Value, 2> newSteps;
76   newSteps.reserve(op.getStep().size());
77   for (auto step : llvm::zip(op.getStep(), tileSizeConstants)) {
78     newSteps.push_back(b.create<arith::MulIOp>(op.getLoc(), std::get<0>(step),
79                                                std::get<1>(step)));
80   }
81   auto outerLoop = b.create<ParallelOp>(op.getLoc(), op.getLowerBound(),
82                                         op.getUpperBound(), newSteps);
83   b.setInsertionPointToStart(outerLoop.getBody());
84 
85   // Compute min(size, dim - offset) to avoid out-of-bounds accesses.
86   auto minMap = AffineMap::get(
87       /*dimCount=*/3, /*symbolCount=*/0,
88       {getAffineDimExpr(/*position=*/0, b.getContext()),
89        getAffineDimExpr(/*position=*/1, b.getContext()) -
90            getAffineDimExpr(/*position=*/2, b.getContext())},
91       b.getContext());
92 
93   // Create the inner loop with adjusted bounds.
94   SmallVector<Value, 2> newBounds;
95   newBounds.reserve(op.getUpperBound().size());
96   bool needInboundCheck = false;
97   for (auto [lowerBound, upperBound, newStep, iv, step, tileSizeConstant] :
98        llvm::zip(outerLoop.getLowerBound(), outerLoop.getUpperBound(),
99                  outerLoop.getStep(), outerLoop.getInductionVars(),
100                  op.getStep(), tileSizeConstants)) {
101     // Collect the statically known loop bounds
102     auto lowerBoundConstant =
103         dyn_cast_or_null<arith::ConstantIndexOp>(lowerBound.getDefiningOp());
104     auto upperBoundConstant =
105         dyn_cast_or_null<arith::ConstantIndexOp>(upperBound.getDefiningOp());
106     auto stepConstant =
107         dyn_cast_or_null<arith::ConstantIndexOp>(step.getDefiningOp());
108     auto tileSize =
109         cast<arith::ConstantIndexOp>(tileSizeConstant.getDefiningOp()).value();
110     // If the loop bounds and the loop step are constant and if the number of
111     // loop iterations is an integer multiple of the tile size, we use a static
112     // bound for the inner loop.
113     if (lowerBoundConstant && upperBoundConstant && stepConstant) {
114       auto numIterations = llvm::divideCeil(upperBoundConstant.value() -
115                                                 lowerBoundConstant.value(),
116                                             stepConstant.value());
117       if (numIterations % tileSize == 0) {
118         newBounds.push_back(newStep);
119         continue;
120       }
121     }
122 
123     // For InboundCheck mode, just use the variable outer step
124     if (noMinMaxBounds) {
125       newBounds.push_back(newStep);
126       needInboundCheck = true;
127       continue;
128     }
129 
130     // Otherwise, we dynamically compute the bound for
131     // each iteration of the outer loop.
132     newBounds.push_back(
133         b.create<affine::AffineMinOp>(op.getLoc(), b.getIndexType(), minMap,
134                                       ValueRange{newStep, upperBound, iv}));
135   }
136   auto innerLoop = b.create<ParallelOp>(
137       op.getLoc(), SmallVector<Value, 2>(newBounds.size(), zero), newBounds,
138       op.getStep());
139 
140   if (noMinMaxBounds && needInboundCheck) {
141     b.setInsertionPointToStart(innerLoop.getBody());
142     // Insert in-bound check
143     Value inbound =
144         b.create<arith::ConstantIntOp>(op.getLoc(), 1, b.getIntegerType(1));
145     for (auto [outerUpperBound, outerIV, innerIV, innerStep] :
146          llvm::zip(outerLoop.getUpperBound(), outerLoop.getInductionVars(),
147                    innerLoop.getInductionVars(), innerLoop.getStep())) {
148       // %in_bound = %in_bound &&
149       //             (%inner_iv * %inner_step + %outer_iv < %outer_upper_bound)
150       Value index = b.create<arith::AddIOp>(
151           op.getLoc(), b.create<arith::MulIOp>(op.getLoc(), innerIV, innerStep),
152           outerIV);
153       Value dimInbound = b.create<arith::CmpIOp>(
154           op.getLoc(), arith::CmpIPredicate::ult, index, outerUpperBound);
155       inbound = b.create<arith::AndIOp>(op.getLoc(), inbound, dimInbound);
156     }
157     auto ifInbound = b.create<IfOp>(op.getLoc(),
158                                     /*resultTypes*/ ArrayRef<Type>{}, inbound,
159                                     /*hasElseRegion*/ false);
160     ifInbound.getThenRegion().takeBody(op.getRegion());
161     Block &thenBlock = ifInbound.getThenRegion().front();
162     // Replace the scf.reduce terminator with an scf.yield terminator.
163     Operation *reduceOp = thenBlock.getTerminator();
164     b.setInsertionPointToEnd(&thenBlock);
165     b.create<scf::YieldOp>(reduceOp->getLoc());
166     reduceOp->erase();
167     b.setInsertionPointToStart(innerLoop.getBody());
168     for (const auto &ivs : llvm::enumerate(llvm::zip(
169              innerLoop.getInductionVars(), outerLoop.getInductionVars()))) {
170       auto newIndex = b.create<arith::AddIOp>(
171           op.getLoc(), std::get<0>(ivs.value()), std::get<1>(ivs.value()));
172       thenBlock.getArgument(ivs.index())
173           .replaceAllUsesExcept(newIndex, newIndex);
174     }
175     thenBlock.eraseArguments(0, thenBlock.getNumArguments());
176   } else {
177     innerLoop.getRegion().takeBody(op.getRegion());
178     b.setInsertionPointToStart(innerLoop.getBody());
179     for (auto ivs : llvm::zip(innerLoop.getInductionVars(),
180                               outerLoop.getInductionVars())) {
181       Value innerIndex = std::get<0>(ivs);
182       auto newIndex = b.create<arith::AddIOp>(op.getLoc(), std::get<0>(ivs),
183                                               std::get<1>(ivs));
184       innerIndex.replaceAllUsesExcept(newIndex, newIndex);
185     }
186   }
187 
188   op.erase();
189   return std::make_pair(outerLoop, innerLoop);
190 }
191 
192 namespace {
193 struct ParallelLoopTiling
194     : public impl::SCFParallelLoopTilingBase<ParallelLoopTiling> {
195   ParallelLoopTiling() = default;
ParallelLoopTiling__anond4f611d80111::ParallelLoopTiling196   explicit ParallelLoopTiling(ArrayRef<int64_t> tileSizes,
197                               bool noMinMaxBounds = false) {
198     this->tileSizes = tileSizes;
199     this->noMinMaxBounds = noMinMaxBounds;
200   }
201 
runOnOperation__anond4f611d80111::ParallelLoopTiling202   void runOnOperation() override {
203     for (auto tileSize : tileSizes)
204       if (tileSize == 0) {
205         mlir::emitError(mlir::UnknownLoc::get(&Pass::getContext()),
206                         "tile size cannot be 0");
207         return signalPassFailure();
208       }
209     auto *parentOp = getOperation();
210     SmallVector<ParallelOp, 2> innermostPloops;
211     getInnermostParallelLoops(parentOp, innermostPloops);
212     for (ParallelOp ploop : innermostPloops) {
213       // FIXME: Add reduction support.
214       if (ploop.getNumReductions() == 0)
215         tileParallelLoop(ploop, tileSizes, noMinMaxBounds);
216     }
217   }
218 };
219 } // namespace
220 
221 std::unique_ptr<Pass>
createParallelLoopTilingPass(ArrayRef<int64_t> tileSizes,bool noMinMaxBounds)222 mlir::createParallelLoopTilingPass(ArrayRef<int64_t> tileSizes,
223                                    bool noMinMaxBounds) {
224   return std::make_unique<ParallelLoopTiling>(tileSizes, noMinMaxBounds);
225 }
226