1 //===- ForallToParallel.cpp - scf.forall to scf.parallel loop conversion --===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // Transforms SCF.ForallOp's into SCF.ParallelOps's. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "mlir/Dialect/SCF/IR/SCF.h" 14 #include "mlir/Dialect/SCF/Transforms/Passes.h" 15 #include "mlir/Dialect/SCF/Transforms/Transforms.h" 16 #include "mlir/IR/PatternMatch.h" 17 18 namespace mlir { 19 #define GEN_PASS_DEF_SCFFORALLTOPARALLELLOOP 20 #include "mlir/Dialect/SCF/Transforms/Passes.h.inc" 21 } // namespace mlir 22 23 using namespace mlir; 24 25 LogicalResult mlir::scf::forallToParallelLoop(RewriterBase &rewriter, 26 scf::ForallOp forallOp, 27 scf::ParallelOp *result) { 28 OpBuilder::InsertionGuard guard(rewriter); 29 rewriter.setInsertionPoint(forallOp); 30 31 Location loc = forallOp.getLoc(); 32 if (!forallOp.getOutputs().empty()) 33 return rewriter.notifyMatchFailure( 34 forallOp, 35 "only fully bufferized scf.forall ops can be lowered to scf.parallel"); 36 37 // Convert mixed bounds and steps to SSA values. 38 SmallVector<Value> lbs = getValueOrCreateConstantIndexOp( 39 rewriter, loc, forallOp.getMixedLowerBound()); 40 SmallVector<Value> ubs = getValueOrCreateConstantIndexOp( 41 rewriter, loc, forallOp.getMixedUpperBound()); 42 SmallVector<Value> steps = 43 getValueOrCreateConstantIndexOp(rewriter, loc, forallOp.getMixedStep()); 44 45 // Create empty scf.parallel op. 46 auto parallelOp = rewriter.create<scf::ParallelOp>(loc, lbs, ubs, steps); 47 rewriter.eraseBlock(¶llelOp.getRegion().front()); 48 rewriter.inlineRegionBefore(forallOp.getRegion(), parallelOp.getRegion(), 49 parallelOp.getRegion().begin()); 50 // Replace the terminator. 51 rewriter.setInsertionPointToEnd(¶llelOp.getRegion().front()); 52 rewriter.replaceOpWithNewOp<scf::ReduceOp>( 53 parallelOp.getRegion().front().getTerminator()); 54 55 // If the mapping attribute is present, propagate to the new parallelOp. 56 if (forallOp.getMapping()) 57 parallelOp->setAttr("mapping", *forallOp.getMapping()); 58 59 // Erase the scf.forall op. 60 rewriter.replaceOp(forallOp, parallelOp); 61 62 if (result) 63 *result = parallelOp; 64 65 return success(); 66 } 67 68 namespace { 69 struct ForallToParallelLoop final 70 : public impl::SCFForallToParallelLoopBase<ForallToParallelLoop> { 71 void runOnOperation() override { 72 Operation *parentOp = getOperation(); 73 IRRewriter rewriter(parentOp->getContext()); 74 75 parentOp->walk([&](scf::ForallOp forallOp) { 76 if (failed(scf::forallToParallelLoop(rewriter, forallOp))) { 77 return signalPassFailure(); 78 } 79 }); 80 } 81 }; 82 } // namespace 83 84 std::unique_ptr<Pass> mlir::createForallToParallelLoopPass() { 85 return std::make_unique<ForallToParallelLoop>(); 86 } 87