//===- ForallToParallel.cpp - scf.forall to scf.parallel loop conversion --===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // // Transforms SCF.ForallOp's into SCF.ParallelOps's. // //===----------------------------------------------------------------------===// #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Dialect/SCF/Transforms/Passes.h" #include "mlir/Dialect/SCF/Transforms/Transforms.h" #include "mlir/IR/PatternMatch.h" namespace mlir { #define GEN_PASS_DEF_SCFFORALLTOPARALLELLOOP #include "mlir/Dialect/SCF/Transforms/Passes.h.inc" } // namespace mlir using namespace mlir; LogicalResult mlir::scf::forallToParallelLoop(RewriterBase &rewriter, scf::ForallOp forallOp, scf::ParallelOp *result) { OpBuilder::InsertionGuard guard(rewriter); rewriter.setInsertionPoint(forallOp); Location loc = forallOp.getLoc(); if (!forallOp.getOutputs().empty()) return rewriter.notifyMatchFailure( forallOp, "only fully bufferized scf.forall ops can be lowered to scf.parallel"); // Convert mixed bounds and steps to SSA values. SmallVector lbs = forallOp.getLowerBound(rewriter); SmallVector ubs = forallOp.getUpperBound(rewriter); SmallVector steps = forallOp.getStep(rewriter); // Create empty scf.parallel op. auto parallelOp = rewriter.create(loc, lbs, ubs, steps); rewriter.eraseBlock(¶llelOp.getRegion().front()); rewriter.inlineRegionBefore(forallOp.getRegion(), parallelOp.getRegion(), parallelOp.getRegion().begin()); // Replace the terminator. rewriter.setInsertionPointToEnd(¶llelOp.getRegion().front()); rewriter.replaceOpWithNewOp( parallelOp.getRegion().front().getTerminator()); // If the mapping attribute is present, propagate to the new parallelOp. if (forallOp.getMapping()) parallelOp->setAttr("mapping", *forallOp.getMapping()); // Erase the scf.forall op. rewriter.replaceOp(forallOp, parallelOp); if (result) *result = parallelOp; return success(); } namespace { struct ForallToParallelLoop final : public impl::SCFForallToParallelLoopBase { void runOnOperation() override { Operation *parentOp = getOperation(); IRRewriter rewriter(parentOp->getContext()); parentOp->walk([&](scf::ForallOp forallOp) { if (failed(scf::forallToParallelLoop(rewriter, forallOp))) { return signalPassFailure(); } }); } }; } // namespace std::unique_ptr mlir::createForallToParallelLoopPass() { return std::make_unique(); }