1 //===- ForallToFor.cpp - scf.forall to scf.for loop conversion ------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // Transforms SCF.ForallOp's into SCF.ForOp's. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "mlir/Dialect/SCF/Transforms/Passes.h" 14 15 #include "mlir/Dialect/SCF/IR/SCF.h" 16 #include "mlir/Dialect/SCF/Transforms/Transforms.h" 17 #include "mlir/IR/PatternMatch.h" 18 19 namespace mlir { 20 #define GEN_PASS_DEF_SCFFORALLTOFORLOOP 21 #include "mlir/Dialect/SCF/Transforms/Passes.h.inc" 22 } // namespace mlir 23 24 using namespace llvm; 25 using namespace mlir; 26 using scf::LoopNest; 27 28 LogicalResult 29 mlir::scf::forallToForLoop(RewriterBase &rewriter, scf::ForallOp forallOp, 30 SmallVectorImpl<Operation *> *results) { 31 OpBuilder::InsertionGuard guard(rewriter); 32 rewriter.setInsertionPoint(forallOp); 33 34 Location loc = forallOp.getLoc(); 35 SmallVector<Value> lbs = forallOp.getLowerBound(rewriter); 36 SmallVector<Value> ubs = forallOp.getUpperBound(rewriter); 37 SmallVector<Value> steps = forallOp.getStep(rewriter); 38 LoopNest loopNest = scf::buildLoopNest(rewriter, loc, lbs, ubs, steps); 39 40 SmallVector<Value> ivs = llvm::map_to_vector( 41 loopNest.loops, [](scf::ForOp loop) { return loop.getInductionVar(); }); 42 43 Block *innermostBlock = loopNest.loops.back().getBody(); 44 rewriter.eraseOp(forallOp.getBody()->getTerminator()); 45 rewriter.inlineBlockBefore(forallOp.getBody(), innermostBlock, 46 innermostBlock->getTerminator()->getIterator(), 47 ivs); 48 rewriter.eraseOp(forallOp); 49 50 if (results) { 51 llvm::move(loopNest.loops, std::back_inserter(*results)); 52 } 53 54 return success(); 55 } 56 57 namespace { 58 struct ForallToForLoop : public impl::SCFForallToForLoopBase<ForallToForLoop> { 59 void runOnOperation() override { 60 Operation *parentOp = getOperation(); 61 IRRewriter rewriter(parentOp->getContext()); 62 63 parentOp->walk([&](scf::ForallOp forallOp) { 64 if (failed(scf::forallToForLoop(rewriter, forallOp))) { 65 return signalPassFailure(); 66 } 67 }); 68 } 69 }; 70 } // namespace 71 72 std::unique_ptr<Pass> mlir::createForallToForLoopPass() { 73 return std::make_unique<ForallToForLoop>(); 74 } 75