xref: /llvm-project/mlir/lib/Dialect/SCF/Transforms/ForallToFor.cpp (revision b3a2208c566c475f7d1b6d40c67aec100ae29103)
1 //===- ForallToFor.cpp - scf.forall to scf.for loop conversion ------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
// Transforms scf.forall ops into nests of sequential scf.for ops.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Dialect/SCF/Transforms/Passes.h"
14 
15 #include "mlir/Dialect/SCF/IR/SCF.h"
16 #include "mlir/Dialect/SCF/Transforms/Transforms.h"
17 #include "mlir/IR/PatternMatch.h"
18 
19 namespace mlir {
20 #define GEN_PASS_DEF_SCFFORALLTOFORLOOP
21 #include "mlir/Dialect/SCF/Transforms/Passes.h.inc"
22 } // namespace mlir
23 
24 using namespace llvm;
25 using namespace mlir;
26 using scf::LoopNest;
27 
28 LogicalResult
29 mlir::scf::forallToForLoop(RewriterBase &rewriter, scf::ForallOp forallOp,
30                            SmallVectorImpl<Operation *> *results) {
31   OpBuilder::InsertionGuard guard(rewriter);
32   rewriter.setInsertionPoint(forallOp);
33 
34   Location loc = forallOp.getLoc();
35   SmallVector<Value> lbs = forallOp.getLowerBound(rewriter);
36   SmallVector<Value> ubs = forallOp.getUpperBound(rewriter);
37   SmallVector<Value> steps = forallOp.getStep(rewriter);
38   LoopNest loopNest = scf::buildLoopNest(rewriter, loc, lbs, ubs, steps);
39 
40   SmallVector<Value> ivs = llvm::map_to_vector(
41       loopNest.loops, [](scf::ForOp loop) { return loop.getInductionVar(); });
42 
43   Block *innermostBlock = loopNest.loops.back().getBody();
44   rewriter.eraseOp(forallOp.getBody()->getTerminator());
45   rewriter.inlineBlockBefore(forallOp.getBody(), innermostBlock,
46                              innermostBlock->getTerminator()->getIterator(),
47                              ivs);
48   rewriter.eraseOp(forallOp);
49 
50   if (results) {
51     llvm::move(loopNest.loops, std::back_inserter(*results));
52   }
53 
54   return success();
55 }
56 
57 namespace {
58 struct ForallToForLoop : public impl::SCFForallToForLoopBase<ForallToForLoop> {
59   void runOnOperation() override {
60     Operation *parentOp = getOperation();
61     IRRewriter rewriter(parentOp->getContext());
62 
63     parentOp->walk([&](scf::ForallOp forallOp) {
64       if (failed(scf::forallToForLoop(rewriter, forallOp))) {
65         return signalPassFailure();
66       }
67     });
68   }
69 };
70 } // namespace
71 
72 std::unique_ptr<Pass> mlir::createForallToForLoopPass() {
73   return std::make_unique<ForallToForLoop>();
74 }
75