1 //===- LegalizeDataValues.cpp - -------------------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/OpenACC/Transforms/Passes.h" 10 11 #include "mlir/Dialect/Func/IR/FuncOps.h" 12 #include "mlir/Dialect/OpenACC/OpenACC.h" 13 #include "mlir/Pass/Pass.h" 14 #include "mlir/Transforms/RegionUtils.h" 15 #include "llvm/Support/ErrorHandling.h" 16 17 namespace mlir { 18 namespace acc { 19 #define GEN_PASS_DEF_LEGALIZEDATAVALUESINREGION 20 #include "mlir/Dialect/OpenACC/Transforms/Passes.h.inc" 21 } // namespace acc 22 } // namespace mlir 23 24 using namespace mlir; 25 26 namespace { 27 28 static bool insideAccComputeRegion(mlir::Operation *op) { 29 mlir::Operation *parent{op->getParentOp()}; 30 while (parent) { 31 if (isa<ACC_COMPUTE_CONSTRUCT_OPS>(parent)) { 32 return true; 33 } 34 parent = parent->getParentOp(); 35 } 36 return false; 37 } 38 39 static void collectPtrs(mlir::ValueRange operands, 40 llvm::SmallVector<std::pair<Value, Value>> &values, 41 bool hostToDevice) { 42 for (auto operand : operands) { 43 Value varPtr = acc::getVarPtr(operand.getDefiningOp()); 44 Value accPtr = acc::getAccPtr(operand.getDefiningOp()); 45 if (varPtr && accPtr) { 46 if (hostToDevice) 47 values.push_back({varPtr, accPtr}); 48 else 49 values.push_back({accPtr, varPtr}); 50 } 51 } 52 } 53 54 template <typename Op> 55 static void replaceAllUsesInAccComputeRegionsWith(Value orig, Value replacement, 56 Region &outerRegion) { 57 for (auto &use : llvm::make_early_inc_range(orig.getUses())) { 58 if (outerRegion.isAncestor(use.getOwner()->getParentRegion())) { 59 if constexpr (std::is_same_v<Op, acc::DataOp> || 60 std::is_same_v<Op, acc::DeclareOp>) { 61 // For data construct regions, only replace uses in contained compute 62 // regions. 63 if (insideAccComputeRegion(use.getOwner())) { 64 use.set(replacement); 65 } 66 } else { 67 use.set(replacement); 68 } 69 } 70 } 71 } 72 73 template <typename Op> 74 static void collectAndReplaceInRegion(Op &op, bool hostToDevice) { 75 llvm::SmallVector<std::pair<Value, Value>> values; 76 77 if constexpr (std::is_same_v<Op, acc::LoopOp>) { 78 collectPtrs(op.getReductionOperands(), values, hostToDevice); 79 collectPtrs(op.getPrivateOperands(), values, hostToDevice); 80 } else { 81 collectPtrs(op.getDataClauseOperands(), values, hostToDevice); 82 if constexpr (!std::is_same_v<Op, acc::KernelsOp> && 83 !std::is_same_v<Op, acc::DataOp> && 84 !std::is_same_v<Op, acc::DeclareOp>) { 85 collectPtrs(op.getReductionOperands(), values, hostToDevice); 86 collectPtrs(op.getPrivateOperands(), values, hostToDevice); 87 collectPtrs(op.getFirstprivateOperands(), values, hostToDevice); 88 } 89 } 90 91 for (auto p : values) 92 replaceAllUsesInAccComputeRegionsWith<Op>(std::get<0>(p), std::get<1>(p), 93 op.getRegion()); 94 } 95 96 class LegalizeDataValuesInRegion 97 : public acc::impl::LegalizeDataValuesInRegionBase< 98 LegalizeDataValuesInRegion> { 99 public: 100 using LegalizeDataValuesInRegionBase< 101 LegalizeDataValuesInRegion>::LegalizeDataValuesInRegionBase; 102 103 void runOnOperation() override { 104 func::FuncOp funcOp = getOperation(); 105 bool replaceHostVsDevice = this->hostToDevice.getValue(); 106 107 funcOp.walk([&](Operation *op) { 108 if (!isa<ACC_COMPUTE_CONSTRUCT_AND_LOOP_OPS>(*op) && 109 !(isa<ACC_DATA_CONSTRUCT_STRUCTURED_OPS>(*op) && 110 applyToAccDataConstruct)) 111 return; 112 113 if (auto parallelOp = dyn_cast<acc::ParallelOp>(*op)) { 114 collectAndReplaceInRegion(parallelOp, replaceHostVsDevice); 115 } else if (auto serialOp = dyn_cast<acc::SerialOp>(*op)) { 116 collectAndReplaceInRegion(serialOp, replaceHostVsDevice); 117 } else if (auto kernelsOp = dyn_cast<acc::KernelsOp>(*op)) { 118 collectAndReplaceInRegion(kernelsOp, replaceHostVsDevice); 119 } else if (auto loopOp = dyn_cast<acc::LoopOp>(*op)) { 120 collectAndReplaceInRegion(loopOp, replaceHostVsDevice); 121 } else if (auto dataOp = dyn_cast<acc::DataOp>(*op)) { 122 collectAndReplaceInRegion(dataOp, replaceHostVsDevice); 123 } else if (auto declareOp = dyn_cast<acc::DeclareOp>(*op)) { 124 collectAndReplaceInRegion(declareOp, replaceHostVsDevice); 125 } else { 126 llvm_unreachable("unsupported acc region op"); 127 } 128 }); 129 } 130 }; 131 132 } // end anonymous namespace 133