xref: /llvm-project/mlir/lib/Dialect/OpenACC/Transforms/LegalizeDataValues.cpp (revision c0a1597029385686942c7cbccb4e998c4b2ab6ef)
1ac9ee618SRazvan Lupusoru //===- LegalizeDataValues.cpp - -------------------------------------------===//
2ac9ee618SRazvan Lupusoru //
3ac9ee618SRazvan Lupusoru // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4ac9ee618SRazvan Lupusoru // See https://llvm.org/LICENSE.txt for license information.
5ac9ee618SRazvan Lupusoru // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6ac9ee618SRazvan Lupusoru //
7ac9ee618SRazvan Lupusoru //===----------------------------------------------------------------------===//
8ac9ee618SRazvan Lupusoru 
9ac9ee618SRazvan Lupusoru #include "mlir/Dialect/OpenACC/Transforms/Passes.h"
10ac9ee618SRazvan Lupusoru 
11ac9ee618SRazvan Lupusoru #include "mlir/Dialect/Func/IR/FuncOps.h"
12ac9ee618SRazvan Lupusoru #include "mlir/Dialect/OpenACC/OpenACC.h"
13ac9ee618SRazvan Lupusoru #include "mlir/Pass/Pass.h"
14ac9ee618SRazvan Lupusoru #include "mlir/Transforms/RegionUtils.h"
15ac9ee618SRazvan Lupusoru #include "llvm/Support/ErrorHandling.h"
16ac9ee618SRazvan Lupusoru 
17ac9ee618SRazvan Lupusoru namespace mlir {
18ac9ee618SRazvan Lupusoru namespace acc {
19ac9ee618SRazvan Lupusoru #define GEN_PASS_DEF_LEGALIZEDATAVALUESINREGION
20ac9ee618SRazvan Lupusoru #include "mlir/Dialect/OpenACC/Transforms/Passes.h.inc"
21ac9ee618SRazvan Lupusoru } // namespace acc
22ac9ee618SRazvan Lupusoru } // namespace mlir
23ac9ee618SRazvan Lupusoru 
24ac9ee618SRazvan Lupusoru using namespace mlir;
25ac9ee618SRazvan Lupusoru 
26ac9ee618SRazvan Lupusoru namespace {
27ac9ee618SRazvan Lupusoru 
28ac9ee618SRazvan Lupusoru static bool insideAccComputeRegion(mlir::Operation *op) {
29ac9ee618SRazvan Lupusoru   mlir::Operation *parent{op->getParentOp()};
30ac9ee618SRazvan Lupusoru   while (parent) {
31ac9ee618SRazvan Lupusoru     if (isa<ACC_COMPUTE_CONSTRUCT_OPS>(parent)) {
32ac9ee618SRazvan Lupusoru       return true;
33ac9ee618SRazvan Lupusoru     }
34ac9ee618SRazvan Lupusoru     parent = parent->getParentOp();
35ac9ee618SRazvan Lupusoru   }
36ac9ee618SRazvan Lupusoru   return false;
37ac9ee618SRazvan Lupusoru }
38ac9ee618SRazvan Lupusoru 
39ac9ee618SRazvan Lupusoru static void collectPtrs(mlir::ValueRange operands,
40ac9ee618SRazvan Lupusoru                         llvm::SmallVector<std::pair<Value, Value>> &values,
41ac9ee618SRazvan Lupusoru                         bool hostToDevice) {
42ac9ee618SRazvan Lupusoru   for (auto operand : operands) {
43ac9ee618SRazvan Lupusoru     Value varPtr = acc::getVarPtr(operand.getDefiningOp());
44ac9ee618SRazvan Lupusoru     Value accPtr = acc::getAccPtr(operand.getDefiningOp());
45ac9ee618SRazvan Lupusoru     if (varPtr && accPtr) {
46ac9ee618SRazvan Lupusoru       if (hostToDevice)
47ac9ee618SRazvan Lupusoru         values.push_back({varPtr, accPtr});
48ac9ee618SRazvan Lupusoru       else
49ac9ee618SRazvan Lupusoru         values.push_back({accPtr, varPtr});
50ac9ee618SRazvan Lupusoru     }
51ac9ee618SRazvan Lupusoru   }
52ac9ee618SRazvan Lupusoru }
53ac9ee618SRazvan Lupusoru 
54ac9ee618SRazvan Lupusoru template <typename Op>
55ac9ee618SRazvan Lupusoru static void replaceAllUsesInAccComputeRegionsWith(Value orig, Value replacement,
56ac9ee618SRazvan Lupusoru                                                   Region &outerRegion) {
57ac9ee618SRazvan Lupusoru   for (auto &use : llvm::make_early_inc_range(orig.getUses())) {
58ac9ee618SRazvan Lupusoru     if (outerRegion.isAncestor(use.getOwner()->getParentRegion())) {
59ac9ee618SRazvan Lupusoru       if constexpr (std::is_same_v<Op, acc::DataOp> ||
60ac9ee618SRazvan Lupusoru                     std::is_same_v<Op, acc::DeclareOp>) {
61ac9ee618SRazvan Lupusoru         // For data construct regions, only replace uses in contained compute
62ac9ee618SRazvan Lupusoru         // regions.
63ac9ee618SRazvan Lupusoru         if (insideAccComputeRegion(use.getOwner())) {
64ac9ee618SRazvan Lupusoru           use.set(replacement);
65ac9ee618SRazvan Lupusoru         }
66ac9ee618SRazvan Lupusoru       } else {
67ac9ee618SRazvan Lupusoru         use.set(replacement);
68ac9ee618SRazvan Lupusoru       }
69ac9ee618SRazvan Lupusoru     }
70ac9ee618SRazvan Lupusoru   }
71ac9ee618SRazvan Lupusoru }
72ac9ee618SRazvan Lupusoru 
73ac9ee618SRazvan Lupusoru template <typename Op>
74ac9ee618SRazvan Lupusoru static void collectAndReplaceInRegion(Op &op, bool hostToDevice) {
75ac9ee618SRazvan Lupusoru   llvm::SmallVector<std::pair<Value, Value>> values;
76ac9ee618SRazvan Lupusoru 
77ac9ee618SRazvan Lupusoru   if constexpr (std::is_same_v<Op, acc::LoopOp>) {
78ac9ee618SRazvan Lupusoru     collectPtrs(op.getReductionOperands(), values, hostToDevice);
79ac9ee618SRazvan Lupusoru     collectPtrs(op.getPrivateOperands(), values, hostToDevice);
80ac9ee618SRazvan Lupusoru   } else {
81ac9ee618SRazvan Lupusoru     collectPtrs(op.getDataClauseOperands(), values, hostToDevice);
82ac9ee618SRazvan Lupusoru     if constexpr (!std::is_same_v<Op, acc::KernelsOp> &&
83ac9ee618SRazvan Lupusoru                   !std::is_same_v<Op, acc::DataOp> &&
84ac9ee618SRazvan Lupusoru                   !std::is_same_v<Op, acc::DeclareOp>) {
85ac9ee618SRazvan Lupusoru       collectPtrs(op.getReductionOperands(), values, hostToDevice);
86*c0a15970SRazvan Lupusoru       collectPtrs(op.getPrivateOperands(), values, hostToDevice);
87*c0a15970SRazvan Lupusoru       collectPtrs(op.getFirstprivateOperands(), values, hostToDevice);
88ac9ee618SRazvan Lupusoru     }
89ac9ee618SRazvan Lupusoru   }
90ac9ee618SRazvan Lupusoru 
91ac9ee618SRazvan Lupusoru   for (auto p : values)
92ac9ee618SRazvan Lupusoru     replaceAllUsesInAccComputeRegionsWith<Op>(std::get<0>(p), std::get<1>(p),
93ac9ee618SRazvan Lupusoru                                               op.getRegion());
94ac9ee618SRazvan Lupusoru }
95ac9ee618SRazvan Lupusoru 
96ac9ee618SRazvan Lupusoru class LegalizeDataValuesInRegion
97ac9ee618SRazvan Lupusoru     : public acc::impl::LegalizeDataValuesInRegionBase<
98ac9ee618SRazvan Lupusoru           LegalizeDataValuesInRegion> {
99ac9ee618SRazvan Lupusoru public:
100ac9ee618SRazvan Lupusoru   using LegalizeDataValuesInRegionBase<
101ac9ee618SRazvan Lupusoru       LegalizeDataValuesInRegion>::LegalizeDataValuesInRegionBase;
102ac9ee618SRazvan Lupusoru 
103ac9ee618SRazvan Lupusoru   void runOnOperation() override {
104ac9ee618SRazvan Lupusoru     func::FuncOp funcOp = getOperation();
105ac9ee618SRazvan Lupusoru     bool replaceHostVsDevice = this->hostToDevice.getValue();
106ac9ee618SRazvan Lupusoru 
107ac9ee618SRazvan Lupusoru     funcOp.walk([&](Operation *op) {
108ac9ee618SRazvan Lupusoru       if (!isa<ACC_COMPUTE_CONSTRUCT_AND_LOOP_OPS>(*op) &&
109ac9ee618SRazvan Lupusoru           !(isa<ACC_DATA_CONSTRUCT_STRUCTURED_OPS>(*op) &&
110ac9ee618SRazvan Lupusoru             applyToAccDataConstruct))
111ac9ee618SRazvan Lupusoru         return;
112ac9ee618SRazvan Lupusoru 
113ac9ee618SRazvan Lupusoru       if (auto parallelOp = dyn_cast<acc::ParallelOp>(*op)) {
114ac9ee618SRazvan Lupusoru         collectAndReplaceInRegion(parallelOp, replaceHostVsDevice);
115ac9ee618SRazvan Lupusoru       } else if (auto serialOp = dyn_cast<acc::SerialOp>(*op)) {
116ac9ee618SRazvan Lupusoru         collectAndReplaceInRegion(serialOp, replaceHostVsDevice);
117ac9ee618SRazvan Lupusoru       } else if (auto kernelsOp = dyn_cast<acc::KernelsOp>(*op)) {
118ac9ee618SRazvan Lupusoru         collectAndReplaceInRegion(kernelsOp, replaceHostVsDevice);
119ac9ee618SRazvan Lupusoru       } else if (auto loopOp = dyn_cast<acc::LoopOp>(*op)) {
120ac9ee618SRazvan Lupusoru         collectAndReplaceInRegion(loopOp, replaceHostVsDevice);
121ac9ee618SRazvan Lupusoru       } else if (auto dataOp = dyn_cast<acc::DataOp>(*op)) {
122ac9ee618SRazvan Lupusoru         collectAndReplaceInRegion(dataOp, replaceHostVsDevice);
123ac9ee618SRazvan Lupusoru       } else if (auto declareOp = dyn_cast<acc::DeclareOp>(*op)) {
124ac9ee618SRazvan Lupusoru         collectAndReplaceInRegion(declareOp, replaceHostVsDevice);
125ac9ee618SRazvan Lupusoru       } else {
126ac9ee618SRazvan Lupusoru         llvm_unreachable("unsupported acc region op");
127ac9ee618SRazvan Lupusoru       }
128ac9ee618SRazvan Lupusoru     });
129ac9ee618SRazvan Lupusoru   }
130ac9ee618SRazvan Lupusoru };
131ac9ee618SRazvan Lupusoru 
132ac9ee618SRazvan Lupusoru } // end anonymous namespace
133