1ac9ee618SRazvan Lupusoru //===- LegalizeDataValues.cpp - -------------------------------------------===// 2ac9ee618SRazvan Lupusoru // 3ac9ee618SRazvan Lupusoru // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4ac9ee618SRazvan Lupusoru // See https://llvm.org/LICENSE.txt for license information. 5ac9ee618SRazvan Lupusoru // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6ac9ee618SRazvan Lupusoru // 7ac9ee618SRazvan Lupusoru //===----------------------------------------------------------------------===// 8ac9ee618SRazvan Lupusoru 9ac9ee618SRazvan Lupusoru #include "mlir/Dialect/OpenACC/Transforms/Passes.h" 10ac9ee618SRazvan Lupusoru 11ac9ee618SRazvan Lupusoru #include "mlir/Dialect/Func/IR/FuncOps.h" 12ac9ee618SRazvan Lupusoru #include "mlir/Dialect/OpenACC/OpenACC.h" 13ac9ee618SRazvan Lupusoru #include "mlir/Pass/Pass.h" 14ac9ee618SRazvan Lupusoru #include "mlir/Transforms/RegionUtils.h" 15ac9ee618SRazvan Lupusoru #include "llvm/Support/ErrorHandling.h" 16ac9ee618SRazvan Lupusoru 17ac9ee618SRazvan Lupusoru namespace mlir { 18ac9ee618SRazvan Lupusoru namespace acc { 19ac9ee618SRazvan Lupusoru #define GEN_PASS_DEF_LEGALIZEDATAVALUESINREGION 20ac9ee618SRazvan Lupusoru #include "mlir/Dialect/OpenACC/Transforms/Passes.h.inc" 21ac9ee618SRazvan Lupusoru } // namespace acc 22ac9ee618SRazvan Lupusoru } // namespace mlir 23ac9ee618SRazvan Lupusoru 24ac9ee618SRazvan Lupusoru using namespace mlir; 25ac9ee618SRazvan Lupusoru 26ac9ee618SRazvan Lupusoru namespace { 27ac9ee618SRazvan Lupusoru 28ac9ee618SRazvan Lupusoru static bool insideAccComputeRegion(mlir::Operation *op) { 29ac9ee618SRazvan Lupusoru mlir::Operation *parent{op->getParentOp()}; 30ac9ee618SRazvan Lupusoru while (parent) { 31ac9ee618SRazvan Lupusoru if (isa<ACC_COMPUTE_CONSTRUCT_OPS>(parent)) { 32ac9ee618SRazvan Lupusoru return true; 33ac9ee618SRazvan Lupusoru } 34ac9ee618SRazvan Lupusoru parent = parent->getParentOp(); 35ac9ee618SRazvan Lupusoru } 36ac9ee618SRazvan Lupusoru return false; 37ac9ee618SRazvan Lupusoru } 38ac9ee618SRazvan Lupusoru 39ac9ee618SRazvan Lupusoru static void collectPtrs(mlir::ValueRange operands, 40ac9ee618SRazvan Lupusoru llvm::SmallVector<std::pair<Value, Value>> &values, 41ac9ee618SRazvan Lupusoru bool hostToDevice) { 42ac9ee618SRazvan Lupusoru for (auto operand : operands) { 43ac9ee618SRazvan Lupusoru Value varPtr = acc::getVarPtr(operand.getDefiningOp()); 44ac9ee618SRazvan Lupusoru Value accPtr = acc::getAccPtr(operand.getDefiningOp()); 45ac9ee618SRazvan Lupusoru if (varPtr && accPtr) { 46ac9ee618SRazvan Lupusoru if (hostToDevice) 47ac9ee618SRazvan Lupusoru values.push_back({varPtr, accPtr}); 48ac9ee618SRazvan Lupusoru else 49ac9ee618SRazvan Lupusoru values.push_back({accPtr, varPtr}); 50ac9ee618SRazvan Lupusoru } 51ac9ee618SRazvan Lupusoru } 52ac9ee618SRazvan Lupusoru } 53ac9ee618SRazvan Lupusoru 54ac9ee618SRazvan Lupusoru template <typename Op> 55ac9ee618SRazvan Lupusoru static void replaceAllUsesInAccComputeRegionsWith(Value orig, Value replacement, 56ac9ee618SRazvan Lupusoru Region &outerRegion) { 57ac9ee618SRazvan Lupusoru for (auto &use : llvm::make_early_inc_range(orig.getUses())) { 58ac9ee618SRazvan Lupusoru if (outerRegion.isAncestor(use.getOwner()->getParentRegion())) { 59ac9ee618SRazvan Lupusoru if constexpr (std::is_same_v<Op, acc::DataOp> || 60ac9ee618SRazvan Lupusoru std::is_same_v<Op, acc::DeclareOp>) { 61ac9ee618SRazvan Lupusoru // For data construct regions, only replace uses in contained compute 62ac9ee618SRazvan Lupusoru // regions. 63ac9ee618SRazvan Lupusoru if (insideAccComputeRegion(use.getOwner())) { 64ac9ee618SRazvan Lupusoru use.set(replacement); 65ac9ee618SRazvan Lupusoru } 66ac9ee618SRazvan Lupusoru } else { 67ac9ee618SRazvan Lupusoru use.set(replacement); 68ac9ee618SRazvan Lupusoru } 69ac9ee618SRazvan Lupusoru } 70ac9ee618SRazvan Lupusoru } 71ac9ee618SRazvan Lupusoru } 72ac9ee618SRazvan Lupusoru 73ac9ee618SRazvan Lupusoru template <typename Op> 74ac9ee618SRazvan Lupusoru static void collectAndReplaceInRegion(Op &op, bool hostToDevice) { 75ac9ee618SRazvan Lupusoru llvm::SmallVector<std::pair<Value, Value>> values; 76ac9ee618SRazvan Lupusoru 77ac9ee618SRazvan Lupusoru if constexpr (std::is_same_v<Op, acc::LoopOp>) { 78ac9ee618SRazvan Lupusoru collectPtrs(op.getReductionOperands(), values, hostToDevice); 79ac9ee618SRazvan Lupusoru collectPtrs(op.getPrivateOperands(), values, hostToDevice); 80ac9ee618SRazvan Lupusoru } else { 81ac9ee618SRazvan Lupusoru collectPtrs(op.getDataClauseOperands(), values, hostToDevice); 82ac9ee618SRazvan Lupusoru if constexpr (!std::is_same_v<Op, acc::KernelsOp> && 83ac9ee618SRazvan Lupusoru !std::is_same_v<Op, acc::DataOp> && 84ac9ee618SRazvan Lupusoru !std::is_same_v<Op, acc::DeclareOp>) { 85ac9ee618SRazvan Lupusoru collectPtrs(op.getReductionOperands(), values, hostToDevice); 86*c0a15970SRazvan Lupusoru collectPtrs(op.getPrivateOperands(), values, hostToDevice); 87*c0a15970SRazvan Lupusoru collectPtrs(op.getFirstprivateOperands(), values, hostToDevice); 88ac9ee618SRazvan Lupusoru } 89ac9ee618SRazvan Lupusoru } 90ac9ee618SRazvan Lupusoru 91ac9ee618SRazvan Lupusoru for (auto p : values) 92ac9ee618SRazvan Lupusoru replaceAllUsesInAccComputeRegionsWith<Op>(std::get<0>(p), std::get<1>(p), 93ac9ee618SRazvan Lupusoru op.getRegion()); 94ac9ee618SRazvan Lupusoru } 95ac9ee618SRazvan Lupusoru 96ac9ee618SRazvan Lupusoru class LegalizeDataValuesInRegion 97ac9ee618SRazvan Lupusoru : public acc::impl::LegalizeDataValuesInRegionBase< 98ac9ee618SRazvan Lupusoru LegalizeDataValuesInRegion> { 99ac9ee618SRazvan Lupusoru public: 100ac9ee618SRazvan Lupusoru using LegalizeDataValuesInRegionBase< 101ac9ee618SRazvan Lupusoru LegalizeDataValuesInRegion>::LegalizeDataValuesInRegionBase; 102ac9ee618SRazvan Lupusoru 103ac9ee618SRazvan Lupusoru void runOnOperation() override { 104ac9ee618SRazvan Lupusoru func::FuncOp funcOp = getOperation(); 105ac9ee618SRazvan Lupusoru bool replaceHostVsDevice = this->hostToDevice.getValue(); 106ac9ee618SRazvan Lupusoru 107ac9ee618SRazvan Lupusoru funcOp.walk([&](Operation *op) { 108ac9ee618SRazvan Lupusoru if (!isa<ACC_COMPUTE_CONSTRUCT_AND_LOOP_OPS>(*op) && 109ac9ee618SRazvan Lupusoru !(isa<ACC_DATA_CONSTRUCT_STRUCTURED_OPS>(*op) && 110ac9ee618SRazvan Lupusoru applyToAccDataConstruct)) 111ac9ee618SRazvan Lupusoru return; 112ac9ee618SRazvan Lupusoru 113ac9ee618SRazvan Lupusoru if (auto parallelOp = dyn_cast<acc::ParallelOp>(*op)) { 114ac9ee618SRazvan Lupusoru collectAndReplaceInRegion(parallelOp, replaceHostVsDevice); 115ac9ee618SRazvan Lupusoru } else if (auto serialOp = dyn_cast<acc::SerialOp>(*op)) { 116ac9ee618SRazvan Lupusoru collectAndReplaceInRegion(serialOp, replaceHostVsDevice); 117ac9ee618SRazvan Lupusoru } else if (auto kernelsOp = dyn_cast<acc::KernelsOp>(*op)) { 118ac9ee618SRazvan Lupusoru collectAndReplaceInRegion(kernelsOp, replaceHostVsDevice); 119ac9ee618SRazvan Lupusoru } else if (auto loopOp = dyn_cast<acc::LoopOp>(*op)) { 120ac9ee618SRazvan Lupusoru collectAndReplaceInRegion(loopOp, replaceHostVsDevice); 121ac9ee618SRazvan Lupusoru } else if (auto dataOp = dyn_cast<acc::DataOp>(*op)) { 122ac9ee618SRazvan Lupusoru collectAndReplaceInRegion(dataOp, replaceHostVsDevice); 123ac9ee618SRazvan Lupusoru } else if (auto declareOp = dyn_cast<acc::DeclareOp>(*op)) { 124ac9ee618SRazvan Lupusoru collectAndReplaceInRegion(declareOp, replaceHostVsDevice); 125ac9ee618SRazvan Lupusoru } else { 126ac9ee618SRazvan Lupusoru llvm_unreachable("unsupported acc region op"); 127ac9ee618SRazvan Lupusoru } 128ac9ee618SRazvan Lupusoru }); 129ac9ee618SRazvan Lupusoru } 130ac9ee618SRazvan Lupusoru }; 131ac9ee618SRazvan Lupusoru 132ac9ee618SRazvan Lupusoru } // end anonymous namespace 133