1 //===- SCFToOpenMP.cpp - Structured Control Flow to OpenMP conversion -----===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements a pass to convert scf.parallel operations into OpenMP 10 // parallel loops. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "mlir/Conversion/SCFToOpenMP/SCFToOpenMP.h" 15 #include "../PassDetail.h" 16 #include "mlir/Dialect/OpenMP/OpenMPDialect.h" 17 #include "mlir/Dialect/SCF/SCF.h" 18 #include "mlir/Transforms/DialectConversion.h" 19 20 using namespace mlir; 21 22 namespace { 23 24 /// Converts SCF parallel operation into an OpenMP workshare loop construct. 25 struct ParallelOpLowering : public OpRewritePattern<scf::ParallelOp> { 26 using OpRewritePattern<scf::ParallelOp>::OpRewritePattern; 27 28 LogicalResult matchAndRewrite(scf::ParallelOp parallelOp, 29 PatternRewriter &rewriter) const override { 30 // TODO: add support for reductions when OpenMP loops have them. 31 if (parallelOp.getNumResults() != 0) 32 return rewriter.notifyMatchFailure( 33 parallelOp, 34 "OpenMP dialect does not yet support loops with reductions"); 35 36 // Replace SCF yield with OpenMP yield. 37 { 38 OpBuilder::InsertionGuard guard(rewriter); 39 rewriter.setInsertionPointToEnd(parallelOp.getBody()); 40 assert(llvm::hasSingleElement(parallelOp.region()) && 41 "expected scf.parallel to have one block"); 42 rewriter.replaceOpWithNewOp<omp::YieldOp>( 43 parallelOp.getBody()->getTerminator(), ValueRange()); 44 } 45 46 // Replace the loop. 47 auto loop = rewriter.create<omp::WsLoopOp>( 48 parallelOp.getLoc(), parallelOp.lowerBound(), parallelOp.upperBound(), 49 parallelOp.step()); 50 rewriter.inlineRegionBefore(parallelOp.region(), loop.region(), 51 loop.region().begin()); 52 rewriter.eraseOp(parallelOp); 53 return success(); 54 } 55 }; 56 57 /// Inserts OpenMP "parallel" operations around top-level SCF "parallel" 58 /// operations in the given function. This is implemented as a direct IR 59 /// modification rather than as a conversion pattern because it does not 60 /// modify the top-level operation it matches, which is a requirement for 61 /// rewrite patterns. 62 // 63 // TODO: consider creating nested parallel operations when necessary. 64 static void insertOpenMPParallel(FuncOp func) { 65 // Collect top-level SCF "parallel" ops. 66 SmallVector<scf::ParallelOp, 4> topLevelParallelOps; 67 func.walk([&topLevelParallelOps](scf::ParallelOp parallelOp) { 68 // Ignore ops that are already within OpenMP parallel construct. 69 if (!parallelOp.getParentOfType<scf::ParallelOp>()) 70 topLevelParallelOps.push_back(parallelOp); 71 }); 72 73 // Wrap SCF ops into OpenMP "parallel" ops. 74 for (scf::ParallelOp parallelOp : topLevelParallelOps) { 75 OpBuilder builder(parallelOp); 76 auto omp = builder.create<omp::ParallelOp>(parallelOp.getLoc()); 77 Block *block = builder.createBlock(&omp.getRegion()); 78 builder.create<omp::TerminatorOp>(parallelOp.getLoc()); 79 block->getOperations().splice( 80 block->begin(), parallelOp.getOperation()->getBlock()->getOperations(), 81 parallelOp.getOperation()); 82 } 83 } 84 85 /// Applies the conversion patterns in the given function. 86 static LogicalResult applyPatterns(FuncOp func) { 87 ConversionTarget target(*func.getContext()); 88 target.addIllegalOp<scf::ParallelOp>(); 89 target.addDynamicallyLegalOp<scf::YieldOp>( 90 [](scf::YieldOp op) { return !isa<scf::ParallelOp>(op.getParentOp()); }); 91 target.addLegalDialect<omp::OpenMPDialect>(); 92 93 OwningRewritePatternList patterns; 94 patterns.insert<ParallelOpLowering>(func.getContext()); 95 FrozenRewritePatternList frozen(std::move(patterns)); 96 return applyPartialConversion(func, target, frozen); 97 } 98 99 /// A pass converting SCF operations to OpenMP operations. 100 struct SCFToOpenMPPass : public ConvertSCFToOpenMPBase<SCFToOpenMPPass> { 101 /// Pass entry point. 102 void runOnFunction() override { 103 insertOpenMPParallel(getFunction()); 104 if (failed(applyPatterns(getFunction()))) 105 signalPassFailure(); 106 } 107 }; 108 109 } // end namespace 110 111 std::unique_ptr<OperationPass<FuncOp>> mlir::createConvertSCFToOpenMPPass() { 112 return std::make_unique<SCFToOpenMPPass>(); 113 } 114