xref: /llvm-project/mlir/lib/Dialect/SCF/Transforms/ForallToParallel.cpp (revision 6b4c12284795a3030e37b17047271a47a69bb587)
//===- ForallToParallel.cpp - scf.forall to scf.parallel loop conversion --===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Transforms scf.forall ops into scf.parallel ops.
//
//===----------------------------------------------------------------------===//
120b665c3dSSpenser Bauman 
130b665c3dSSpenser Bauman #include "mlir/Dialect/SCF/IR/SCF.h"
140b665c3dSSpenser Bauman #include "mlir/Dialect/SCF/Transforms/Passes.h"
150b665c3dSSpenser Bauman #include "mlir/Dialect/SCF/Transforms/Transforms.h"
160b665c3dSSpenser Bauman #include "mlir/IR/PatternMatch.h"
170b665c3dSSpenser Bauman 
// Expand the generated pass definitions from Passes.h.inc for the
// SCFForallToParallelLoop pass. This is what provides
// impl::SCFForallToParallelLoopBase, the base class of the pass defined below.
namespace mlir {
#define GEN_PASS_DEF_SCFFORALLTOPARALLELLOOP
#include "mlir/Dialect/SCF/Transforms/Passes.h.inc"
} // namespace mlir
220b665c3dSSpenser Bauman 
230b665c3dSSpenser Bauman using namespace mlir;
240b665c3dSSpenser Bauman 
forallToParallelLoop(RewriterBase & rewriter,scf::ForallOp forallOp,scf::ParallelOp * result)250b665c3dSSpenser Bauman LogicalResult mlir::scf::forallToParallelLoop(RewriterBase &rewriter,
260b665c3dSSpenser Bauman                                               scf::ForallOp forallOp,
270b665c3dSSpenser Bauman                                               scf::ParallelOp *result) {
280b665c3dSSpenser Bauman   OpBuilder::InsertionGuard guard(rewriter);
290b665c3dSSpenser Bauman   rewriter.setInsertionPoint(forallOp);
300b665c3dSSpenser Bauman 
310b665c3dSSpenser Bauman   Location loc = forallOp.getLoc();
320b665c3dSSpenser Bauman   if (!forallOp.getOutputs().empty())
330b665c3dSSpenser Bauman     return rewriter.notifyMatchFailure(
340b665c3dSSpenser Bauman         forallOp,
350b665c3dSSpenser Bauman         "only fully bufferized scf.forall ops can be lowered to scf.parallel");
360b665c3dSSpenser Bauman 
370b665c3dSSpenser Bauman   // Convert mixed bounds and steps to SSA values.
38*6b4c1228Ssrcarroll   SmallVector<Value> lbs = forallOp.getLowerBound(rewriter);
39*6b4c1228Ssrcarroll   SmallVector<Value> ubs = forallOp.getUpperBound(rewriter);
40*6b4c1228Ssrcarroll   SmallVector<Value> steps = forallOp.getStep(rewriter);
410b665c3dSSpenser Bauman 
420b665c3dSSpenser Bauman   // Create empty scf.parallel op.
430b665c3dSSpenser Bauman   auto parallelOp = rewriter.create<scf::ParallelOp>(loc, lbs, ubs, steps);
440b665c3dSSpenser Bauman   rewriter.eraseBlock(&parallelOp.getRegion().front());
450b665c3dSSpenser Bauman   rewriter.inlineRegionBefore(forallOp.getRegion(), parallelOp.getRegion(),
460b665c3dSSpenser Bauman                               parallelOp.getRegion().begin());
470b665c3dSSpenser Bauman   // Replace the terminator.
480b665c3dSSpenser Bauman   rewriter.setInsertionPointToEnd(&parallelOp.getRegion().front());
490b665c3dSSpenser Bauman   rewriter.replaceOpWithNewOp<scf::ReduceOp>(
500b665c3dSSpenser Bauman       parallelOp.getRegion().front().getTerminator());
510b665c3dSSpenser Bauman 
520b665c3dSSpenser Bauman   // If the mapping attribute is present, propagate to the new parallelOp.
530b665c3dSSpenser Bauman   if (forallOp.getMapping())
540b665c3dSSpenser Bauman     parallelOp->setAttr("mapping", *forallOp.getMapping());
550b665c3dSSpenser Bauman 
560b665c3dSSpenser Bauman   // Erase the scf.forall op.
570b665c3dSSpenser Bauman   rewriter.replaceOp(forallOp, parallelOp);
580b665c3dSSpenser Bauman 
590b665c3dSSpenser Bauman   if (result)
600b665c3dSSpenser Bauman     *result = parallelOp;
610b665c3dSSpenser Bauman 
620b665c3dSSpenser Bauman   return success();
630b665c3dSSpenser Bauman }
640b665c3dSSpenser Bauman 
650b665c3dSSpenser Bauman namespace {
660b665c3dSSpenser Bauman struct ForallToParallelLoop final
670b665c3dSSpenser Bauman     : public impl::SCFForallToParallelLoopBase<ForallToParallelLoop> {
runOnOperation__anonc6d64d9a0111::ForallToParallelLoop680b665c3dSSpenser Bauman   void runOnOperation() override {
690b665c3dSSpenser Bauman     Operation *parentOp = getOperation();
700b665c3dSSpenser Bauman     IRRewriter rewriter(parentOp->getContext());
710b665c3dSSpenser Bauman 
720b665c3dSSpenser Bauman     parentOp->walk([&](scf::ForallOp forallOp) {
730b665c3dSSpenser Bauman       if (failed(scf::forallToParallelLoop(rewriter, forallOp))) {
740b665c3dSSpenser Bauman         return signalPassFailure();
750b665c3dSSpenser Bauman       }
760b665c3dSSpenser Bauman     });
770b665c3dSSpenser Bauman   }
780b665c3dSSpenser Bauman };
790b665c3dSSpenser Bauman } // namespace
800b665c3dSSpenser Bauman 
createForallToParallelLoopPass()810b665c3dSSpenser Bauman std::unique_ptr<Pass> mlir::createForallToParallelLoopPass() {
820b665c3dSSpenser Bauman   return std::make_unique<ForallToParallelLoop>();
830b665c3dSSpenser Bauman }
84