xref: /llvm-project/mlir/lib/Dialect/OpenACC/Transforms/LegalizeDataValues.cpp (revision c0a1597029385686942c7cbccb4e998c4b2ab6ef)
1 //===- LegalizeDataValues.cpp - -------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/OpenACC/Transforms/Passes.h"
10 
11 #include "mlir/Dialect/Func/IR/FuncOps.h"
12 #include "mlir/Dialect/OpenACC/OpenACC.h"
13 #include "mlir/Pass/Pass.h"
14 #include "mlir/Transforms/RegionUtils.h"
15 #include "llvm/Support/ErrorHandling.h"
16 
17 namespace mlir {
18 namespace acc {
19 #define GEN_PASS_DEF_LEGALIZEDATAVALUESINREGION
20 #include "mlir/Dialect/OpenACC/Transforms/Passes.h.inc"
21 } // namespace acc
22 } // namespace mlir
23 
24 using namespace mlir;
25 
26 namespace {
27 
28 static bool insideAccComputeRegion(mlir::Operation *op) {
29   mlir::Operation *parent{op->getParentOp()};
30   while (parent) {
31     if (isa<ACC_COMPUTE_CONSTRUCT_OPS>(parent)) {
32       return true;
33     }
34     parent = parent->getParentOp();
35   }
36   return false;
37 }
38 
39 static void collectPtrs(mlir::ValueRange operands,
40                         llvm::SmallVector<std::pair<Value, Value>> &values,
41                         bool hostToDevice) {
42   for (auto operand : operands) {
43     Value varPtr = acc::getVarPtr(operand.getDefiningOp());
44     Value accPtr = acc::getAccPtr(operand.getDefiningOp());
45     if (varPtr && accPtr) {
46       if (hostToDevice)
47         values.push_back({varPtr, accPtr});
48       else
49         values.push_back({accPtr, varPtr});
50     }
51   }
52 }
53 
54 template <typename Op>
55 static void replaceAllUsesInAccComputeRegionsWith(Value orig, Value replacement,
56                                                   Region &outerRegion) {
57   for (auto &use : llvm::make_early_inc_range(orig.getUses())) {
58     if (outerRegion.isAncestor(use.getOwner()->getParentRegion())) {
59       if constexpr (std::is_same_v<Op, acc::DataOp> ||
60                     std::is_same_v<Op, acc::DeclareOp>) {
61         // For data construct regions, only replace uses in contained compute
62         // regions.
63         if (insideAccComputeRegion(use.getOwner())) {
64           use.set(replacement);
65         }
66       } else {
67         use.set(replacement);
68       }
69     }
70   }
71 }
72 
73 template <typename Op>
74 static void collectAndReplaceInRegion(Op &op, bool hostToDevice) {
75   llvm::SmallVector<std::pair<Value, Value>> values;
76 
77   if constexpr (std::is_same_v<Op, acc::LoopOp>) {
78     collectPtrs(op.getReductionOperands(), values, hostToDevice);
79     collectPtrs(op.getPrivateOperands(), values, hostToDevice);
80   } else {
81     collectPtrs(op.getDataClauseOperands(), values, hostToDevice);
82     if constexpr (!std::is_same_v<Op, acc::KernelsOp> &&
83                   !std::is_same_v<Op, acc::DataOp> &&
84                   !std::is_same_v<Op, acc::DeclareOp>) {
85       collectPtrs(op.getReductionOperands(), values, hostToDevice);
86       collectPtrs(op.getPrivateOperands(), values, hostToDevice);
87       collectPtrs(op.getFirstprivateOperands(), values, hostToDevice);
88     }
89   }
90 
91   for (auto p : values)
92     replaceAllUsesInAccComputeRegionsWith<Op>(std::get<0>(p), std::get<1>(p),
93                                               op.getRegion());
94 }
95 
96 class LegalizeDataValuesInRegion
97     : public acc::impl::LegalizeDataValuesInRegionBase<
98           LegalizeDataValuesInRegion> {
99 public:
100   using LegalizeDataValuesInRegionBase<
101       LegalizeDataValuesInRegion>::LegalizeDataValuesInRegionBase;
102 
103   void runOnOperation() override {
104     func::FuncOp funcOp = getOperation();
105     bool replaceHostVsDevice = this->hostToDevice.getValue();
106 
107     funcOp.walk([&](Operation *op) {
108       if (!isa<ACC_COMPUTE_CONSTRUCT_AND_LOOP_OPS>(*op) &&
109           !(isa<ACC_DATA_CONSTRUCT_STRUCTURED_OPS>(*op) &&
110             applyToAccDataConstruct))
111         return;
112 
113       if (auto parallelOp = dyn_cast<acc::ParallelOp>(*op)) {
114         collectAndReplaceInRegion(parallelOp, replaceHostVsDevice);
115       } else if (auto serialOp = dyn_cast<acc::SerialOp>(*op)) {
116         collectAndReplaceInRegion(serialOp, replaceHostVsDevice);
117       } else if (auto kernelsOp = dyn_cast<acc::KernelsOp>(*op)) {
118         collectAndReplaceInRegion(kernelsOp, replaceHostVsDevice);
119       } else if (auto loopOp = dyn_cast<acc::LoopOp>(*op)) {
120         collectAndReplaceInRegion(loopOp, replaceHostVsDevice);
121       } else if (auto dataOp = dyn_cast<acc::DataOp>(*op)) {
122         collectAndReplaceInRegion(dataOp, replaceHostVsDevice);
123       } else if (auto declareOp = dyn_cast<acc::DeclareOp>(*op)) {
124         collectAndReplaceInRegion(declareOp, replaceHostVsDevice);
125       } else {
126         llvm_unreachable("unsupported acc region op");
127       }
128     });
129   }
130 };
131 
132 } // end anonymous namespace
133