xref: /llvm-project/mlir/lib/Dialect/SCF/Transforms/ForallToFor.cpp (revision b3a2208c566c475f7d1b6d40c67aec100ae29103)
1286bd42aSJorn Tuyls //===- ForallToFor.cpp - scf.forall to scf.for loop conversion ------------===//
2286bd42aSJorn Tuyls //
3286bd42aSJorn Tuyls // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4286bd42aSJorn Tuyls // See https://llvm.org/LICENSE.txt for license information.
5286bd42aSJorn Tuyls // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6286bd42aSJorn Tuyls //
7286bd42aSJorn Tuyls //===----------------------------------------------------------------------===//
8286bd42aSJorn Tuyls //
9286bd42aSJorn Tuyls // Transforms SCF.ForallOp's into SCF.ForOp's.
10286bd42aSJorn Tuyls //
11286bd42aSJorn Tuyls //===----------------------------------------------------------------------===//
12286bd42aSJorn Tuyls 
13286bd42aSJorn Tuyls #include "mlir/Dialect/SCF/Transforms/Passes.h"
14286bd42aSJorn Tuyls 
15286bd42aSJorn Tuyls #include "mlir/Dialect/SCF/IR/SCF.h"
16286bd42aSJorn Tuyls #include "mlir/Dialect/SCF/Transforms/Transforms.h"
17286bd42aSJorn Tuyls #include "mlir/IR/PatternMatch.h"
18286bd42aSJorn Tuyls 
19286bd42aSJorn Tuyls namespace mlir {
20286bd42aSJorn Tuyls #define GEN_PASS_DEF_SCFFORALLTOFORLOOP
21286bd42aSJorn Tuyls #include "mlir/Dialect/SCF/Transforms/Passes.h.inc"
22286bd42aSJorn Tuyls } // namespace mlir
23286bd42aSJorn Tuyls 
24286bd42aSJorn Tuyls using namespace llvm;
25286bd42aSJorn Tuyls using namespace mlir;
26286bd42aSJorn Tuyls using scf::LoopNest;
27286bd42aSJorn Tuyls 
28286bd42aSJorn Tuyls LogicalResult
29286bd42aSJorn Tuyls mlir::scf::forallToForLoop(RewriterBase &rewriter, scf::ForallOp forallOp,
30286bd42aSJorn Tuyls                            SmallVectorImpl<Operation *> *results) {
31286bd42aSJorn Tuyls   OpBuilder::InsertionGuard guard(rewriter);
32286bd42aSJorn Tuyls   rewriter.setInsertionPoint(forallOp);
33286bd42aSJorn Tuyls 
34286bd42aSJorn Tuyls   Location loc = forallOp.getLoc();
35*6b4c1228Ssrcarroll   SmallVector<Value> lbs = forallOp.getLowerBound(rewriter);
36*6b4c1228Ssrcarroll   SmallVector<Value> ubs = forallOp.getUpperBound(rewriter);
37*6b4c1228Ssrcarroll   SmallVector<Value> steps = forallOp.getStep(rewriter);
38286bd42aSJorn Tuyls   LoopNest loopNest = scf::buildLoopNest(rewriter, loc, lbs, ubs, steps);
39286bd42aSJorn Tuyls 
40286bd42aSJorn Tuyls   SmallVector<Value> ivs = llvm::map_to_vector(
41286bd42aSJorn Tuyls       loopNest.loops, [](scf::ForOp loop) { return loop.getInductionVar(); });
42286bd42aSJorn Tuyls 
43286bd42aSJorn Tuyls   Block *innermostBlock = loopNest.loops.back().getBody();
44286bd42aSJorn Tuyls   rewriter.eraseOp(forallOp.getBody()->getTerminator());
45286bd42aSJorn Tuyls   rewriter.inlineBlockBefore(forallOp.getBody(), innermostBlock,
46286bd42aSJorn Tuyls                              innermostBlock->getTerminator()->getIterator(),
47286bd42aSJorn Tuyls                              ivs);
48286bd42aSJorn Tuyls   rewriter.eraseOp(forallOp);
49286bd42aSJorn Tuyls 
50286bd42aSJorn Tuyls   if (results) {
51286bd42aSJorn Tuyls     llvm::move(loopNest.loops, std::back_inserter(*results));
52286bd42aSJorn Tuyls   }
53286bd42aSJorn Tuyls 
54286bd42aSJorn Tuyls   return success();
55286bd42aSJorn Tuyls }
56286bd42aSJorn Tuyls 
57286bd42aSJorn Tuyls namespace {
58286bd42aSJorn Tuyls struct ForallToForLoop : public impl::SCFForallToForLoopBase<ForallToForLoop> {
59286bd42aSJorn Tuyls   void runOnOperation() override {
60286bd42aSJorn Tuyls     Operation *parentOp = getOperation();
61286bd42aSJorn Tuyls     IRRewriter rewriter(parentOp->getContext());
62286bd42aSJorn Tuyls 
63286bd42aSJorn Tuyls     parentOp->walk([&](scf::ForallOp forallOp) {
64286bd42aSJorn Tuyls       if (failed(scf::forallToForLoop(rewriter, forallOp))) {
65286bd42aSJorn Tuyls         return signalPassFailure();
66286bd42aSJorn Tuyls       }
67286bd42aSJorn Tuyls     });
68286bd42aSJorn Tuyls   }
69286bd42aSJorn Tuyls };
70286bd42aSJorn Tuyls } // namespace
71286bd42aSJorn Tuyls 
72286bd42aSJorn Tuyls std::unique_ptr<Pass> mlir::createForallToForLoopPass() {
73286bd42aSJorn Tuyls   return std::make_unique<ForallToForLoop>();
74286bd42aSJorn Tuyls }
75