1 //===- OpenACCToSCF.cpp - OpenACC condition to SCF if conversion ----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #include "mlir/Conversion/OpenACCToSCF/ConvertOpenACCToSCF.h"
10
11 #include "mlir/Dialect/Arith/IR/Arith.h"
12 #include "mlir/Dialect/OpenACC/OpenACC.h"
13 #include "mlir/Dialect/SCF/IR/SCF.h"
14 #include "mlir/IR/Matchers.h"
15 #include "mlir/Pass/Pass.h"
16 #include "mlir/Transforms/DialectConversion.h"
17
18 namespace mlir {
19 #define GEN_PASS_DEF_CONVERTOPENACCTOSCF
20 #include "mlir/Conversion/Passes.h.inc"
21 } // namespace mlir
22
23 using namespace mlir;
24
25 //===----------------------------------------------------------------------===//
26 // Conversion patterns
27 //===----------------------------------------------------------------------===//
28
29 namespace {
30 /// Pattern to transform the `getIfCond` on operation without region into a
31 /// scf.if and move the operation into the `then` region.
32 template <typename OpTy>
33 class ExpandIfCondition : public OpRewritePattern<OpTy> {
34 using OpRewritePattern<OpTy>::OpRewritePattern;
35
matchAndRewrite(OpTy op,PatternRewriter & rewriter) const36 LogicalResult matchAndRewrite(OpTy op,
37 PatternRewriter &rewriter) const override {
38 // Early exit if there is no condition.
39 if (!op.getIfCond())
40 return failure();
41
42 IntegerAttr constAttr;
43 if (!matchPattern(op.getIfCond(), m_Constant(&constAttr))) {
44 auto ifOp = rewriter.create<scf::IfOp>(op.getLoc(), TypeRange(),
45 op.getIfCond(), false);
46 rewriter.modifyOpInPlace(op, [&]() { op.getIfCondMutable().erase(0); });
47 auto thenBodyBuilder = ifOp.getThenBodyBuilder(rewriter.getListener());
48 thenBodyBuilder.clone(*op.getOperation());
49 rewriter.eraseOp(op);
50 } else {
51 if (constAttr.getInt())
52 rewriter.modifyOpInPlace(op, [&]() { op.getIfCondMutable().erase(0); });
53 else
54 rewriter.eraseOp(op);
55 }
56 return success();
57 }
58 };
59 } // namespace
60
populateOpenACCToSCFConversionPatterns(RewritePatternSet & patterns)61 void mlir::populateOpenACCToSCFConversionPatterns(RewritePatternSet &patterns) {
62 patterns.add<ExpandIfCondition<acc::EnterDataOp>>(patterns.getContext());
63 patterns.add<ExpandIfCondition<acc::ExitDataOp>>(patterns.getContext());
64 patterns.add<ExpandIfCondition<acc::UpdateOp>>(patterns.getContext());
65 }
66
67 namespace {
68 struct ConvertOpenACCToSCFPass
69 : public impl::ConvertOpenACCToSCFBase<ConvertOpenACCToSCFPass> {
70 void runOnOperation() override;
71 };
72 } // namespace
73
runOnOperation()74 void ConvertOpenACCToSCFPass::runOnOperation() {
75 auto op = getOperation();
76 auto *context = op.getContext();
77
78 RewritePatternSet patterns(context);
79 ConversionTarget target(*context);
80 populateOpenACCToSCFConversionPatterns(patterns);
81
82 target.addLegalDialect<scf::SCFDialect>();
83 target.addLegalDialect<acc::OpenACCDialect>();
84
85 target.addDynamicallyLegalOp<acc::EnterDataOp>(
86 [](acc::EnterDataOp op) { return !op.getIfCond(); });
87
88 target.addDynamicallyLegalOp<acc::ExitDataOp>(
89 [](acc::ExitDataOp op) { return !op.getIfCond(); });
90
91 target.addDynamicallyLegalOp<acc::UpdateOp>(
92 [](acc::UpdateOp op) { return !op.getIfCond(); });
93
94 if (failed(applyPartialConversion(op, target, std::move(patterns))))
95 signalPassFailure();
96 }
97
createConvertOpenACCToSCFPass()98 std::unique_ptr<OperationPass<ModuleOp>> mlir::createConvertOpenACCToSCFPass() {
99 return std::make_unique<ConvertOpenACCToSCFPass>();
100 }
101