1113b8070SValentin Clement //===- OpenACCToLLVMIRTranslation.cpp -------------------------------------===//
2113b8070SValentin Clement //
3113b8070SValentin Clement // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4113b8070SValentin Clement // See https://llvm.org/LICENSE.txt for license information.
5113b8070SValentin Clement // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6113b8070SValentin Clement //
7113b8070SValentin Clement //===----------------------------------------------------------------------===//
8113b8070SValentin Clement //
9113b8070SValentin Clement // This file implements a translation between the MLIR OpenACC dialect and LLVM
10113b8070SValentin Clement // IR.
11113b8070SValentin Clement //
12113b8070SValentin Clement //===----------------------------------------------------------------------===//
13113b8070SValentin Clement
14113b8070SValentin Clement #include "mlir/Target/LLVMIR/Dialect/OpenACC/OpenACCToLLVMIRTranslation.h"
15*b00e0c16SChristian Ulmann #include "mlir/Analysis/TopologicalSortUtils.h"
16113b8070SValentin Clement #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
17113b8070SValentin Clement #include "mlir/Dialect/OpenACC/OpenACC.h"
18113b8070SValentin Clement #include "mlir/IR/BuiltinOps.h"
19113b8070SValentin Clement #include "mlir/IR/Operation.h"
20113b8070SValentin Clement #include "mlir/Support/LLVM.h"
212d373e4dSAkash Banerjee #include "mlir/Target/LLVMIR/Dialect/OpenMPCommon.h"
22113b8070SValentin Clement #include "mlir/Target/LLVMIR/ModuleTranslation.h"
23113b8070SValentin Clement
24113b8070SValentin Clement #include "llvm/ADT/TypeSwitch.h"
25113b8070SValentin Clement #include "llvm/Frontend/OpenMP/OMPConstants.h"
26113b8070SValentin Clement #include "llvm/Support/FormatVariadic.h"
27113b8070SValentin Clement
28113b8070SValentin Clement using namespace mlir;
29113b8070SValentin Clement
30113b8070SValentin Clement using OpenACCIRBuilder = llvm::OpenMPIRBuilder;
31113b8070SValentin Clement
32113b8070SValentin Clement //===----------------------------------------------------------------------===//
33113b8070SValentin Clement // Utility functions
34113b8070SValentin Clement //===----------------------------------------------------------------------===//
35113b8070SValentin Clement
// Map-type bits handed to the offloading runtime. The numeric values are taken
// from openmp/libomptarget/include/omptarget.h; every OpenACC data clause
// handled in this file is lowered onto one of them (or a combination).
static constexpr uint64_t kCreateFlag{0x000};       // `alloc`: no transfer
static constexpr uint64_t kDeviceCopyinFlag{0x001}; // `to` the device
static constexpr uint64_t kHostCopyoutFlag{0x002};  // `from` the device
static constexpr uint64_t kDeleteFlag{0x008};       // `delete`
static constexpr uint64_t kPresentFlag{0x1000};     // `present`
// Runtime extension implementing the second OpenACC reference counter.
static constexpr uint64_t kHoldFlag{0x2000};

/// Device id passed to the runtime mapper calls (default device).
static constexpr int64_t kDefaultDevice{-1};
48113b8070SValentin Clement
49113b8070SValentin Clement /// Create the location struct from the operation location information.
createSourceLocationInfo(OpenACCIRBuilder & builder,Operation * op)50ab5ff154SValentin Clement static llvm::Value *createSourceLocationInfo(OpenACCIRBuilder &builder,
51ab5ff154SValentin Clement Operation *op) {
52ab5ff154SValentin Clement auto loc = op->getLoc();
53ab5ff154SValentin Clement auto funcOp = op->getParentOfType<LLVM::LLVMFuncOp>();
54113b8070SValentin Clement StringRef funcName = funcOp ? funcOp.getName() : "unknown";
55944aa042SJohannes Doerfert uint32_t strLen;
562d373e4dSAkash Banerjee llvm::Constant *locStr = mlir::LLVM::createSourceLocStrFromLocation(
572d373e4dSAkash Banerjee loc, builder, funcName, strLen);
58944aa042SJohannes Doerfert return builder.getOrCreateIdent(locStr, strLen);
59113b8070SValentin Clement }
60113b8070SValentin Clement
61113b8070SValentin Clement /// Return the runtime function used to lower the given operation.
getAssociatedFunction(OpenACCIRBuilder & builder,Operation * op)62113b8070SValentin Clement static llvm::Function *getAssociatedFunction(OpenACCIRBuilder &builder,
63ab5ff154SValentin Clement Operation *op) {
64ab5ff154SValentin Clement return llvm::TypeSwitch<Operation *, llvm::Function *>(op)
65ab5ff154SValentin Clement .Case([&](acc::EnterDataOp) {
66113b8070SValentin Clement return builder.getOrCreateRuntimeFunctionPtr(
67113b8070SValentin Clement llvm::omp::OMPRTL___tgt_target_data_begin_mapper);
68ab5ff154SValentin Clement })
69ab5ff154SValentin Clement .Case([&](acc::ExitDataOp) {
70ab5ff154SValentin Clement return builder.getOrCreateRuntimeFunctionPtr(
71ab5ff154SValentin Clement llvm::omp::OMPRTL___tgt_target_data_end_mapper);
721005ef44SValentin Clement })
731005ef44SValentin Clement .Case([&](acc::UpdateOp) {
741005ef44SValentin Clement return builder.getOrCreateRuntimeFunctionPtr(
751005ef44SValentin Clement llvm::omp::OMPRTL___tgt_target_data_update_mapper);
76ab5ff154SValentin Clement });
77113b8070SValentin Clement llvm_unreachable("Unknown OpenACC operation");
78113b8070SValentin Clement }
79113b8070SValentin Clement
80113b8070SValentin Clement /// Extract pointer, size and mapping information from operands
81113b8070SValentin Clement /// to populate the future functions arguments.
82113b8070SValentin Clement static LogicalResult
processOperands(llvm::IRBuilderBase & builder,LLVM::ModuleTranslation & moduleTranslation,Operation * op,ValueRange operands,unsigned totalNbOperand,uint64_t operandFlag,SmallVector<uint64_t> & flags,SmallVectorImpl<llvm::Constant * > & names,unsigned & index,struct OpenACCIRBuilder::MapperAllocas & mapperAllocas)83113b8070SValentin Clement processOperands(llvm::IRBuilderBase &builder,
84ab5ff154SValentin Clement LLVM::ModuleTranslation &moduleTranslation, Operation *op,
85113b8070SValentin Clement ValueRange operands, unsigned totalNbOperand,
86113b8070SValentin Clement uint64_t operandFlag, SmallVector<uint64_t> &flags,
87fe7ca1a9SValentin Clement SmallVectorImpl<llvm::Constant *> &names, unsigned &index,
88fe7ca1a9SValentin Clement struct OpenACCIRBuilder::MapperAllocas &mapperAllocas) {
89113b8070SValentin Clement OpenACCIRBuilder *accBuilder = moduleTranslation.getOpenMPBuilder();
90113b8070SValentin Clement llvm::LLVMContext &ctx = builder.getContext();
917b9d73c2SPaulo Matos auto *i8PtrTy = llvm::PointerType::getUnqual(ctx);
92113b8070SValentin Clement auto *arrI8PtrTy = llvm::ArrayType::get(i8PtrTy, totalNbOperand);
93113b8070SValentin Clement auto *i64Ty = llvm::Type::getInt64Ty(ctx);
94113b8070SValentin Clement auto *arrI64Ty = llvm::ArrayType::get(i64Ty, totalNbOperand);
95113b8070SValentin Clement
96113b8070SValentin Clement for (Value data : operands) {
97113b8070SValentin Clement llvm::Value *dataValue = moduleTranslation.lookupValue(data);
98113b8070SValentin Clement
99113b8070SValentin Clement llvm::Value *dataPtrBase;
100113b8070SValentin Clement llvm::Value *dataPtr;
101113b8070SValentin Clement llvm::Value *dataSize;
102113b8070SValentin Clement
1035550c821STres Popp if (isa<LLVM::LLVMPointerType>(data.getType())) {
104113b8070SValentin Clement dataPtrBase = dataValue;
105113b8070SValentin Clement dataPtr = dataValue;
1062d373e4dSAkash Banerjee dataSize = accBuilder->getSizeInBytes(dataValue);
107113b8070SValentin Clement } else {
108ab5ff154SValentin Clement return op->emitOpError()
109113b8070SValentin Clement << "Data operand must be legalized before translation."
110113b8070SValentin Clement << "Unsupported type: " << data.getType();
111113b8070SValentin Clement }
112113b8070SValentin Clement
113113b8070SValentin Clement // Store base pointer extracted from operand into the i-th position of
114113b8070SValentin Clement // argBase.
115113b8070SValentin Clement llvm::Value *ptrBaseGEP = builder.CreateInBoundsGEP(
116fe7ca1a9SValentin Clement arrI8PtrTy, mapperAllocas.ArgsBase,
117fe7ca1a9SValentin Clement {builder.getInt32(0), builder.getInt32(index)});
118645b7795SYoungsuk Kim builder.CreateStore(dataPtrBase, ptrBaseGEP);
119113b8070SValentin Clement
120113b8070SValentin Clement // Store pointer extracted from operand into the i-th position of args.
121113b8070SValentin Clement llvm::Value *ptrGEP = builder.CreateInBoundsGEP(
122fe7ca1a9SValentin Clement arrI8PtrTy, mapperAllocas.Args,
123fe7ca1a9SValentin Clement {builder.getInt32(0), builder.getInt32(index)});
124645b7795SYoungsuk Kim builder.CreateStore(dataPtr, ptrGEP);
125113b8070SValentin Clement
126113b8070SValentin Clement // Store size extracted from operand into the i-th position of argSizes.
127113b8070SValentin Clement llvm::Value *sizeGEP = builder.CreateInBoundsGEP(
128fe7ca1a9SValentin Clement arrI64Ty, mapperAllocas.ArgSizes,
129fe7ca1a9SValentin Clement {builder.getInt32(0), builder.getInt32(index)});
130113b8070SValentin Clement builder.CreateStore(dataSize, sizeGEP);
131113b8070SValentin Clement
132113b8070SValentin Clement flags.push_back(operandFlag);
133113b8070SValentin Clement llvm::Constant *mapName =
1342d373e4dSAkash Banerjee mlir::LLVM::createMappingInformation(data.getLoc(), *accBuilder);
135113b8070SValentin Clement names.push_back(mapName);
136113b8070SValentin Clement ++index;
137113b8070SValentin Clement }
138113b8070SValentin Clement return success();
139113b8070SValentin Clement }
140113b8070SValentin Clement
141ab5ff154SValentin Clement /// Process data operands from acc::EnterDataOp
142fe7ca1a9SValentin Clement static LogicalResult
processDataOperands(llvm::IRBuilderBase & builder,LLVM::ModuleTranslation & moduleTranslation,acc::EnterDataOp op,SmallVector<uint64_t> & flags,SmallVectorImpl<llvm::Constant * > & names,struct OpenACCIRBuilder::MapperAllocas & mapperAllocas)143fe7ca1a9SValentin Clement processDataOperands(llvm::IRBuilderBase &builder,
144fe7ca1a9SValentin Clement LLVM::ModuleTranslation &moduleTranslation,
145ab5ff154SValentin Clement acc::EnterDataOp op, SmallVector<uint64_t> &flags,
146fe7ca1a9SValentin Clement SmallVectorImpl<llvm::Constant *> &names,
147fe7ca1a9SValentin Clement struct OpenACCIRBuilder::MapperAllocas &mapperAllocas) {
148ab5ff154SValentin Clement // TODO add `create_zero` and `attach` operands
149ab5ff154SValentin Clement
150cf846705SValentin Clement unsigned index = 0;
151cf846705SValentin Clement
152ab5ff154SValentin Clement // Create operands are handled as `alloc` call.
1539dec07f4SValentin Clement // Copyin operands are handled as `to` call.
1549dec07f4SValentin Clement llvm::SmallVector<mlir::Value> create, copyin;
1559dec07f4SValentin Clement for (mlir::Value dataOp : op.getDataClauseOperands()) {
1569dec07f4SValentin Clement if (auto createOp =
1579dec07f4SValentin Clement mlir::dyn_cast_or_null<acc::CreateOp>(dataOp.getDefiningOp())) {
1589dec07f4SValentin Clement create.push_back(createOp.getVarPtr());
1599dec07f4SValentin Clement } else if (auto copyinOp = mlir::dyn_cast_or_null<acc::CopyinOp>(
1609dec07f4SValentin Clement dataOp.getDefiningOp())) {
1619dec07f4SValentin Clement copyin.push_back(copyinOp.getVarPtr());
1629dec07f4SValentin Clement }
1639dec07f4SValentin Clement }
1649dec07f4SValentin Clement
1659dec07f4SValentin Clement auto nbTotalOperands = create.size() + copyin.size();
1669dec07f4SValentin Clement
1679dec07f4SValentin Clement // Create operands are handled as `alloc` call.
1689dec07f4SValentin Clement if (failed(processOperands(builder, moduleTranslation, op, create,
1699dec07f4SValentin Clement nbTotalOperands, kCreateFlag, flags, names, index,
1709dec07f4SValentin Clement mapperAllocas)))
171ab5ff154SValentin Clement return failure();
172ab5ff154SValentin Clement
173ab5ff154SValentin Clement // Copyin operands are handled as `to` call.
1749dec07f4SValentin Clement if (failed(processOperands(builder, moduleTranslation, op, copyin,
1759dec07f4SValentin Clement nbTotalOperands, kDeviceCopyinFlag, flags, names,
1769dec07f4SValentin Clement index, mapperAllocas)))
177ab5ff154SValentin Clement return failure();
178ab5ff154SValentin Clement
179ab5ff154SValentin Clement return success();
180ab5ff154SValentin Clement }
181ab5ff154SValentin Clement
182ab5ff154SValentin Clement /// Process data operands from acc::ExitDataOp
183fe7ca1a9SValentin Clement static LogicalResult
processDataOperands(llvm::IRBuilderBase & builder,LLVM::ModuleTranslation & moduleTranslation,acc::ExitDataOp op,SmallVector<uint64_t> & flags,SmallVectorImpl<llvm::Constant * > & names,struct OpenACCIRBuilder::MapperAllocas & mapperAllocas)184fe7ca1a9SValentin Clement processDataOperands(llvm::IRBuilderBase &builder,
185fe7ca1a9SValentin Clement LLVM::ModuleTranslation &moduleTranslation,
186ab5ff154SValentin Clement acc::ExitDataOp op, SmallVector<uint64_t> &flags,
187fe7ca1a9SValentin Clement SmallVectorImpl<llvm::Constant *> &names,
188fe7ca1a9SValentin Clement struct OpenACCIRBuilder::MapperAllocas &mapperAllocas) {
189ab5ff154SValentin Clement // TODO add `detach` operands
190ab5ff154SValentin Clement
191cf846705SValentin Clement unsigned index = 0;
192cf846705SValentin Clement
19315a480c0SValentin Clement llvm::SmallVector<mlir::Value> deleteOperands, copyoutOperands;
19415a480c0SValentin Clement for (mlir::Value dataOp : op.getDataClauseOperands()) {
19515a480c0SValentin Clement if (auto devicePtrOp = mlir::dyn_cast_or_null<acc::GetDevicePtrOp>(
19615a480c0SValentin Clement dataOp.getDefiningOp())) {
19715a480c0SValentin Clement for (auto &u : devicePtrOp.getAccPtr().getUses()) {
19815a480c0SValentin Clement if (mlir::dyn_cast_or_null<acc::DeleteOp>(u.getOwner()))
19915a480c0SValentin Clement deleteOperands.push_back(devicePtrOp.getVarPtr());
20015a480c0SValentin Clement else if (mlir::dyn_cast_or_null<acc::CopyoutOp>(u.getOwner()))
20115a480c0SValentin Clement copyoutOperands.push_back(devicePtrOp.getVarPtr());
20215a480c0SValentin Clement }
20315a480c0SValentin Clement }
20415a480c0SValentin Clement }
20515a480c0SValentin Clement
20615a480c0SValentin Clement auto nbTotalOperands = deleteOperands.size() + copyoutOperands.size();
20715a480c0SValentin Clement
208ab5ff154SValentin Clement // Delete operands are handled as `delete` call.
20915a480c0SValentin Clement if (failed(processOperands(builder, moduleTranslation, op, deleteOperands,
21015a480c0SValentin Clement nbTotalOperands, kDeleteFlag, flags, names, index,
21115a480c0SValentin Clement mapperAllocas)))
212ab5ff154SValentin Clement return failure();
213ab5ff154SValentin Clement
214ab5ff154SValentin Clement // Copyout operands are handled as `from` call.
21515a480c0SValentin Clement if (failed(processOperands(builder, moduleTranslation, op, copyoutOperands,
21615a480c0SValentin Clement nbTotalOperands, kHostCopyoutFlag, flags, names,
21715a480c0SValentin Clement index, mapperAllocas)))
2181005ef44SValentin Clement return failure();
2191005ef44SValentin Clement
2201005ef44SValentin Clement return success();
2211005ef44SValentin Clement }
2221005ef44SValentin Clement
2231005ef44SValentin Clement /// Process data operands from acc::UpdateOp
224fe7ca1a9SValentin Clement static LogicalResult
processDataOperands(llvm::IRBuilderBase & builder,LLVM::ModuleTranslation & moduleTranslation,acc::UpdateOp op,SmallVector<uint64_t> & flags,SmallVectorImpl<llvm::Constant * > & names,struct OpenACCIRBuilder::MapperAllocas & mapperAllocas)225fe7ca1a9SValentin Clement processDataOperands(llvm::IRBuilderBase &builder,
226fe7ca1a9SValentin Clement LLVM::ModuleTranslation &moduleTranslation,
2271005ef44SValentin Clement acc::UpdateOp op, SmallVector<uint64_t> &flags,
228fe7ca1a9SValentin Clement SmallVectorImpl<llvm::Constant *> &names,
229fe7ca1a9SValentin Clement struct OpenACCIRBuilder::MapperAllocas &mapperAllocas) {
230cf846705SValentin Clement unsigned index = 0;
2311005ef44SValentin Clement
2321005ef44SValentin Clement // Host operands are handled as `from` call.
233689afa88SValentin Clement // Device operands are handled as `to` call.
234689afa88SValentin Clement llvm::SmallVector<mlir::Value> from, to;
235689afa88SValentin Clement for (mlir::Value dataOp : op.getDataClauseOperands()) {
236689afa88SValentin Clement if (auto getDevicePtrOp = mlir::dyn_cast_or_null<acc::GetDevicePtrOp>(
237689afa88SValentin Clement dataOp.getDefiningOp())) {
238689afa88SValentin Clement from.push_back(getDevicePtrOp.getVarPtr());
239689afa88SValentin Clement } else if (auto updateDeviceOp =
240689afa88SValentin Clement mlir::dyn_cast_or_null<acc::UpdateDeviceOp>(
241689afa88SValentin Clement dataOp.getDefiningOp())) {
242689afa88SValentin Clement to.push_back(updateDeviceOp.getVarPtr());
243689afa88SValentin Clement }
244689afa88SValentin Clement }
245689afa88SValentin Clement
246689afa88SValentin Clement if (failed(processOperands(builder, moduleTranslation, op, from, from.size(),
247f9806b3eSRiver Riddle kHostCopyoutFlag, flags, names, index,
248f9806b3eSRiver Riddle mapperAllocas)))
2491005ef44SValentin Clement return failure();
2501005ef44SValentin Clement
251689afa88SValentin Clement if (failed(processOperands(builder, moduleTranslation, op, to, to.size(),
252fe7ca1a9SValentin Clement kDeviceCopyinFlag, flags, names, index,
253fe7ca1a9SValentin Clement mapperAllocas)))
254ab5ff154SValentin Clement return failure();
255ab5ff154SValentin Clement return success();
256ab5ff154SValentin Clement }
257ab5ff154SValentin Clement
258113b8070SValentin Clement //===----------------------------------------------------------------------===//
259113b8070SValentin Clement // Conversion functions
260113b8070SValentin Clement //===----------------------------------------------------------------------===//
261113b8070SValentin Clement
262fe7ca1a9SValentin Clement /// Converts an OpenACC data operation into LLVM IR.
convertDataOp(acc::DataOp & op,llvm::IRBuilderBase & builder,LLVM::ModuleTranslation & moduleTranslation)263fe7ca1a9SValentin Clement static LogicalResult convertDataOp(acc::DataOp &op,
264fe7ca1a9SValentin Clement llvm::IRBuilderBase &builder,
265fe7ca1a9SValentin Clement LLVM::ModuleTranslation &moduleTranslation) {
266fe7ca1a9SValentin Clement llvm::LLVMContext &ctx = builder.getContext();
267fe7ca1a9SValentin Clement auto enclosingFuncOp = op.getOperation()->getParentOfType<LLVM::LLVMFuncOp>();
268fe7ca1a9SValentin Clement llvm::Function *enclosingFunction =
269fe7ca1a9SValentin Clement moduleTranslation.lookupFunction(enclosingFuncOp.getName());
270fe7ca1a9SValentin Clement
271fe7ca1a9SValentin Clement OpenACCIRBuilder *accBuilder = moduleTranslation.getOpenMPBuilder();
272fe7ca1a9SValentin Clement
273fe7ca1a9SValentin Clement llvm::Value *srcLocInfo = createSourceLocationInfo(*accBuilder, op);
274fe7ca1a9SValentin Clement
275fe7ca1a9SValentin Clement llvm::Function *beginMapperFunc = accBuilder->getOrCreateRuntimeFunctionPtr(
276fe7ca1a9SValentin Clement llvm::omp::OMPRTL___tgt_target_data_begin_mapper);
277fe7ca1a9SValentin Clement
278fe7ca1a9SValentin Clement llvm::Function *endMapperFunc = accBuilder->getOrCreateRuntimeFunctionPtr(
279fe7ca1a9SValentin Clement llvm::omp::OMPRTL___tgt_target_data_end_mapper);
280fe7ca1a9SValentin Clement
281fe7ca1a9SValentin Clement // Number of arguments in the data operation.
282fe7ca1a9SValentin Clement unsigned totalNbOperand = op.getNumDataOperands();
283fe7ca1a9SValentin Clement
284fe7ca1a9SValentin Clement struct OpenACCIRBuilder::MapperAllocas mapperAllocas;
285fe7ca1a9SValentin Clement OpenACCIRBuilder::InsertPointTy allocaIP(
286fe7ca1a9SValentin Clement &enclosingFunction->getEntryBlock(),
287fe7ca1a9SValentin Clement enclosingFunction->getEntryBlock().getFirstInsertionPt());
288fe7ca1a9SValentin Clement accBuilder->createMapperAllocas(builder.saveIP(), allocaIP, totalNbOperand,
289fe7ca1a9SValentin Clement mapperAllocas);
290fe7ca1a9SValentin Clement
291fe7ca1a9SValentin Clement SmallVector<uint64_t> flags;
292fe7ca1a9SValentin Clement SmallVector<llvm::Constant *> names;
293fe7ca1a9SValentin Clement unsigned index = 0;
294fe7ca1a9SValentin Clement
295fe7ca1a9SValentin Clement // TODO handle no_create, deviceptr and attach operands.
296fe7ca1a9SValentin Clement
29746e1b095SValentin Clement llvm::SmallVector<mlir::Value> copyin, copyout, create, present,
29846e1b095SValentin Clement deleteOperands;
29946e1b095SValentin Clement for (mlir::Value dataOp : op.getDataClauseOperands()) {
30046e1b095SValentin Clement if (auto devicePtrOp = mlir::dyn_cast_or_null<acc::GetDevicePtrOp>(
30146e1b095SValentin Clement dataOp.getDefiningOp())) {
30246e1b095SValentin Clement for (auto &u : devicePtrOp.getAccPtr().getUses()) {
30346e1b095SValentin Clement if (mlir::dyn_cast_or_null<acc::DeleteOp>(u.getOwner())) {
30446e1b095SValentin Clement deleteOperands.push_back(devicePtrOp.getVarPtr());
30546e1b095SValentin Clement } else if (mlir::dyn_cast_or_null<acc::CopyoutOp>(u.getOwner())) {
30646e1b095SValentin Clement // TODO copyout zero currenlty handled as copyout. Update when
30746e1b095SValentin Clement // extension available.
30846e1b095SValentin Clement copyout.push_back(devicePtrOp.getVarPtr());
30946e1b095SValentin Clement }
31046e1b095SValentin Clement }
31146e1b095SValentin Clement } else if (auto copyinOp = mlir::dyn_cast_or_null<acc::CopyinOp>(
31246e1b095SValentin Clement dataOp.getDefiningOp())) {
313fe7ca1a9SValentin Clement // TODO copyin readonly currenlty handled as copyin. Update when extension
314fe7ca1a9SValentin Clement // available.
31546e1b095SValentin Clement copyin.push_back(copyinOp.getVarPtr());
31646e1b095SValentin Clement } else if (auto createOp = mlir::dyn_cast_or_null<acc::CreateOp>(
31746e1b095SValentin Clement dataOp.getDefiningOp())) {
318fe7ca1a9SValentin Clement // TODO create zero currenlty handled as create. Update when extension
319fe7ca1a9SValentin Clement // available.
32046e1b095SValentin Clement create.push_back(createOp.getVarPtr());
32146e1b095SValentin Clement } else if (auto presentOp = mlir::dyn_cast_or_null<acc::PresentOp>(
32246e1b095SValentin Clement dataOp.getDefiningOp())) {
32346e1b095SValentin Clement present.push_back(createOp.getVarPtr());
32446e1b095SValentin Clement }
32546e1b095SValentin Clement }
32646e1b095SValentin Clement
32746e1b095SValentin Clement auto nbTotalOperands = copyin.size() + copyout.size() + create.size() +
32846e1b095SValentin Clement present.size() + deleteOperands.size();
32946e1b095SValentin Clement
33046e1b095SValentin Clement // Copyin operands are handled as `to` call.
33146e1b095SValentin Clement if (failed(processOperands(builder, moduleTranslation, op, copyin,
33246e1b095SValentin Clement nbTotalOperands, kDeviceCopyinFlag | kHoldFlag,
33346e1b095SValentin Clement flags, names, index, mapperAllocas)))
33446e1b095SValentin Clement return failure();
33546e1b095SValentin Clement
33646e1b095SValentin Clement // Delete operands are handled as `delete` call.
33746e1b095SValentin Clement if (failed(processOperands(builder, moduleTranslation, op, deleteOperands,
33846e1b095SValentin Clement nbTotalOperands, kDeleteFlag, flags, names, index,
339d6929aaaSValentin Clement mapperAllocas)))
340fe7ca1a9SValentin Clement return failure();
341fe7ca1a9SValentin Clement
34246e1b095SValentin Clement // Copyout operands are handled as `from` call.
34346e1b095SValentin Clement if (failed(processOperands(builder, moduleTranslation, op, copyout,
34446e1b095SValentin Clement nbTotalOperands, kHostCopyoutFlag | kHoldFlag,
34546e1b095SValentin Clement flags, names, index, mapperAllocas)))
34646e1b095SValentin Clement return failure();
34746e1b095SValentin Clement
34846e1b095SValentin Clement // Create operands are handled as `alloc` call.
34946e1b095SValentin Clement if (failed(processOperands(builder, moduleTranslation, op, create,
35046e1b095SValentin Clement nbTotalOperands, kCreateFlag | kHoldFlag, flags,
35146e1b095SValentin Clement names, index, mapperAllocas)))
35246e1b095SValentin Clement return failure();
35346e1b095SValentin Clement
35446e1b095SValentin Clement if (failed(processOperands(builder, moduleTranslation, op, present,
35546e1b095SValentin Clement nbTotalOperands, kPresentFlag | kHoldFlag, flags,
35646e1b095SValentin Clement names, index, mapperAllocas)))
357fe7ca1a9SValentin Clement return failure();
358fe7ca1a9SValentin Clement
359fe7ca1a9SValentin Clement llvm::GlobalVariable *maptypes =
360fe7ca1a9SValentin Clement accBuilder->createOffloadMaptypes(flags, ".offload_maptypes");
361fe7ca1a9SValentin Clement llvm::Value *maptypesArg = builder.CreateConstInBoundsGEP2_32(
362fe7ca1a9SValentin Clement llvm::ArrayType::get(llvm::Type::getInt64Ty(ctx), totalNbOperand),
363fe7ca1a9SValentin Clement maptypes, /*Idx0=*/0, /*Idx1=*/0);
364fe7ca1a9SValentin Clement
365fe7ca1a9SValentin Clement llvm::GlobalVariable *mapnames =
366fe7ca1a9SValentin Clement accBuilder->createOffloadMapnames(names, ".offload_mapnames");
367fe7ca1a9SValentin Clement llvm::Value *mapnamesArg = builder.CreateConstInBoundsGEP2_32(
3687b9d73c2SPaulo Matos llvm::ArrayType::get(llvm::PointerType::getUnqual(ctx), totalNbOperand),
369fe7ca1a9SValentin Clement mapnames, /*Idx0=*/0, /*Idx1=*/0);
370fe7ca1a9SValentin Clement
371fe7ca1a9SValentin Clement // Create call to start the data region.
372fe7ca1a9SValentin Clement accBuilder->emitMapperCall(builder.saveIP(), beginMapperFunc, srcLocInfo,
373fe7ca1a9SValentin Clement maptypesArg, mapnamesArg, mapperAllocas,
374fe7ca1a9SValentin Clement kDefaultDevice, totalNbOperand);
375fe7ca1a9SValentin Clement
376fe7ca1a9SValentin Clement // Convert the region.
377fe7ca1a9SValentin Clement llvm::BasicBlock *entryBlock = nullptr;
378fe7ca1a9SValentin Clement
379f9806b3eSRiver Riddle for (Block &bb : op.getRegion()) {
380fe7ca1a9SValentin Clement llvm::BasicBlock *llvmBB = llvm::BasicBlock::Create(
381fe7ca1a9SValentin Clement ctx, "acc.data", builder.GetInsertBlock()->getParent());
382fe7ca1a9SValentin Clement if (entryBlock == nullptr)
383fe7ca1a9SValentin Clement entryBlock = llvmBB;
384fe7ca1a9SValentin Clement moduleTranslation.mapBlock(&bb, llvmBB);
385fe7ca1a9SValentin Clement }
386fe7ca1a9SValentin Clement
387fe7ca1a9SValentin Clement auto afterDataRegion = builder.saveIP();
388fe7ca1a9SValentin Clement
389fe7ca1a9SValentin Clement llvm::BranchInst *sourceTerminator = builder.CreateBr(entryBlock);
390fe7ca1a9SValentin Clement
391fe7ca1a9SValentin Clement builder.restoreIP(afterDataRegion);
392fe7ca1a9SValentin Clement llvm::BasicBlock *endDataBlock = llvm::BasicBlock::Create(
393fe7ca1a9SValentin Clement ctx, "acc.end_data", builder.GetInsertBlock()->getParent());
394fe7ca1a9SValentin Clement
395e919df57SChristian Ulmann SetVector<Block *> blocks = getBlocksSortedByDominance(op.getRegion());
396fe7ca1a9SValentin Clement for (Block *bb : blocks) {
397fe7ca1a9SValentin Clement llvm::BasicBlock *llvmBB = moduleTranslation.lookupBlock(bb);
398fe7ca1a9SValentin Clement if (bb->isEntryBlock()) {
399fe7ca1a9SValentin Clement assert(sourceTerminator->getNumSuccessors() == 1 &&
400fe7ca1a9SValentin Clement "provided entry block has multiple successors");
401fe7ca1a9SValentin Clement sourceTerminator->setSuccessor(0, llvmBB);
402fe7ca1a9SValentin Clement }
403fe7ca1a9SValentin Clement
404fe7ca1a9SValentin Clement if (failed(
405fe7ca1a9SValentin Clement moduleTranslation.convertBlock(*bb, bb->isEntryBlock(), builder))) {
406fe7ca1a9SValentin Clement return failure();
407fe7ca1a9SValentin Clement }
408fe7ca1a9SValentin Clement
409fe7ca1a9SValentin Clement if (isa<acc::TerminatorOp, acc::YieldOp>(bb->getTerminator()))
410fe7ca1a9SValentin Clement builder.CreateBr(endDataBlock);
411fe7ca1a9SValentin Clement }
412fe7ca1a9SValentin Clement
413fe7ca1a9SValentin Clement // Create call to end the data region.
414fe7ca1a9SValentin Clement builder.SetInsertPoint(endDataBlock);
415fe7ca1a9SValentin Clement accBuilder->emitMapperCall(builder.saveIP(), endMapperFunc, srcLocInfo,
416fe7ca1a9SValentin Clement maptypesArg, mapnamesArg, mapperAllocas,
417fe7ca1a9SValentin Clement kDefaultDevice, totalNbOperand);
418fe7ca1a9SValentin Clement
419fe7ca1a9SValentin Clement return success();
420fe7ca1a9SValentin Clement }
421fe7ca1a9SValentin Clement
422ab5ff154SValentin Clement /// Converts an OpenACC standalone data operation into LLVM IR.
423ab5ff154SValentin Clement template <typename OpTy>
424113b8070SValentin Clement static LogicalResult
convertStandaloneDataOp(OpTy & op,llvm::IRBuilderBase & builder,LLVM::ModuleTranslation & moduleTranslation)425ab5ff154SValentin Clement convertStandaloneDataOp(OpTy &op, llvm::IRBuilderBase &builder,
426113b8070SValentin Clement LLVM::ModuleTranslation &moduleTranslation) {
427ab5ff154SValentin Clement auto enclosingFuncOp =
428ab5ff154SValentin Clement op.getOperation()->template getParentOfType<LLVM::LLVMFuncOp>();
429113b8070SValentin Clement llvm::Function *enclosingFunction =
430113b8070SValentin Clement moduleTranslation.lookupFunction(enclosingFuncOp.getName());
431113b8070SValentin Clement
432113b8070SValentin Clement OpenACCIRBuilder *accBuilder = moduleTranslation.getOpenMPBuilder();
433113b8070SValentin Clement
434ab5ff154SValentin Clement auto *srcLocInfo = createSourceLocationInfo(*accBuilder, op);
435113b8070SValentin Clement auto *mapperFunc = getAssociatedFunction(*accBuilder, op);
436113b8070SValentin Clement
437113b8070SValentin Clement // Number of arguments in the enter_data operation.
438ab5ff154SValentin Clement unsigned totalNbOperand = op.getNumDataOperands();
439113b8070SValentin Clement
440113b8070SValentin Clement llvm::LLVMContext &ctx = builder.getContext();
441fe7ca1a9SValentin Clement
442fe7ca1a9SValentin Clement struct OpenACCIRBuilder::MapperAllocas mapperAllocas;
443fe7ca1a9SValentin Clement OpenACCIRBuilder::InsertPointTy allocaIP(
444113b8070SValentin Clement &enclosingFunction->getEntryBlock(),
445113b8070SValentin Clement enclosingFunction->getEntryBlock().getFirstInsertionPt());
446fe7ca1a9SValentin Clement accBuilder->createMapperAllocas(builder.saveIP(), allocaIP, totalNbOperand,
447fe7ca1a9SValentin Clement mapperAllocas);
448113b8070SValentin Clement
449113b8070SValentin Clement SmallVector<uint64_t> flags;
450113b8070SValentin Clement SmallVector<llvm::Constant *> names;
451113b8070SValentin Clement
452ab5ff154SValentin Clement if (failed(processDataOperands(builder, moduleTranslation, op, flags, names,
453fe7ca1a9SValentin Clement mapperAllocas)))
454113b8070SValentin Clement return failure();
455113b8070SValentin Clement
456113b8070SValentin Clement llvm::GlobalVariable *maptypes =
457113b8070SValentin Clement accBuilder->createOffloadMaptypes(flags, ".offload_maptypes");
458113b8070SValentin Clement llvm::Value *maptypesArg = builder.CreateConstInBoundsGEP2_32(
459113b8070SValentin Clement llvm::ArrayType::get(llvm::Type::getInt64Ty(ctx), totalNbOperand),
460113b8070SValentin Clement maptypes, /*Idx0=*/0, /*Idx1=*/0);
461113b8070SValentin Clement
462113b8070SValentin Clement llvm::GlobalVariable *mapnames =
463113b8070SValentin Clement accBuilder->createOffloadMapnames(names, ".offload_mapnames");
464113b8070SValentin Clement llvm::Value *mapnamesArg = builder.CreateConstInBoundsGEP2_32(
4657b9d73c2SPaulo Matos llvm::ArrayType::get(llvm::PointerType::getUnqual(ctx), totalNbOperand),
466113b8070SValentin Clement mapnames, /*Idx0=*/0, /*Idx1=*/0);
467113b8070SValentin Clement
468fe7ca1a9SValentin Clement accBuilder->emitMapperCall(builder.saveIP(), mapperFunc, srcLocInfo,
469fe7ca1a9SValentin Clement maptypesArg, mapnamesArg, mapperAllocas,
470fe7ca1a9SValentin Clement kDefaultDevice, totalNbOperand);
471113b8070SValentin Clement
472113b8070SValentin Clement return success();
473113b8070SValentin Clement }
474113b8070SValentin Clement
475113b8070SValentin Clement namespace {
476113b8070SValentin Clement
477113b8070SValentin Clement /// Implementation of the dialect interface that converts operations belonging
478113b8070SValentin Clement /// to the OpenACC dialect to LLVM IR.
479113b8070SValentin Clement class OpenACCDialectLLVMIRTranslationInterface
480113b8070SValentin Clement : public LLVMTranslationDialectInterface {
481113b8070SValentin Clement public:
482113b8070SValentin Clement using LLVMTranslationDialectInterface::LLVMTranslationDialectInterface;
483113b8070SValentin Clement
484113b8070SValentin Clement /// Translates the given operation to LLVM IR using the provided IR builder
485113b8070SValentin Clement /// and saving the state in `moduleTranslation`.
486113b8070SValentin Clement LogicalResult
487113b8070SValentin Clement convertOperation(Operation *op, llvm::IRBuilderBase &builder,
488113b8070SValentin Clement LLVM::ModuleTranslation &moduleTranslation) const final;
489113b8070SValentin Clement };
490113b8070SValentin Clement
491be0a7e9fSMehdi Amini } // namespace
492113b8070SValentin Clement
493113b8070SValentin Clement /// Given an OpenACC MLIR operation, create the corresponding LLVM IR
494113b8070SValentin Clement /// (including OpenACC runtime calls).
convertOperation(Operation * op,llvm::IRBuilderBase & builder,LLVM::ModuleTranslation & moduleTranslation) const495113b8070SValentin Clement LogicalResult OpenACCDialectLLVMIRTranslationInterface::convertOperation(
496113b8070SValentin Clement Operation *op, llvm::IRBuilderBase &builder,
497113b8070SValentin Clement LLVM::ModuleTranslation &moduleTranslation) const {
498113b8070SValentin Clement
499113b8070SValentin Clement return llvm::TypeSwitch<Operation *, LogicalResult>(op)
500fe7ca1a9SValentin Clement .Case([&](acc::DataOp dataOp) {
501fe7ca1a9SValentin Clement return convertDataOp(dataOp, builder, moduleTranslation);
502fe7ca1a9SValentin Clement })
503ab5ff154SValentin Clement .Case([&](acc::EnterDataOp enterDataOp) {
504ab5ff154SValentin Clement return convertStandaloneDataOp<acc::EnterDataOp>(enterDataOp, builder,
505ab5ff154SValentin Clement moduleTranslation);
506ab5ff154SValentin Clement })
507ab5ff154SValentin Clement .Case([&](acc::ExitDataOp exitDataOp) {
508ab5ff154SValentin Clement return convertStandaloneDataOp<acc::ExitDataOp>(exitDataOp, builder,
509ab5ff154SValentin Clement moduleTranslation);
510113b8070SValentin Clement })
5111005ef44SValentin Clement .Case([&](acc::UpdateOp updateOp) {
5121005ef44SValentin Clement return convertStandaloneDataOp<acc::UpdateOp>(updateOp, builder,
5131005ef44SValentin Clement moduleTranslation);
5141005ef44SValentin Clement })
515fe7ca1a9SValentin Clement .Case<acc::TerminatorOp, acc::YieldOp>([](auto op) {
516fe7ca1a9SValentin Clement // `yield` and `terminator` can be just omitted. The block structure was
517fe7ca1a9SValentin Clement // created in the function that handles their parent operation.
518fe7ca1a9SValentin Clement assert(op->getNumOperands() == 0 &&
519fe7ca1a9SValentin Clement "unexpected OpenACC terminator with operands");
520fe7ca1a9SValentin Clement return success();
521fe7ca1a9SValentin Clement })
52215a480c0SValentin Clement .Case<acc::CreateOp, acc::CopyinOp, acc::CopyoutOp, acc::DeleteOp,
52315a480c0SValentin Clement acc::UpdateDeviceOp, acc::GetDevicePtrOp>([](auto op) {
524689afa88SValentin Clement // NOP
525689afa88SValentin Clement return success();
526689afa88SValentin Clement })
527113b8070SValentin Clement .Default([&](Operation *op) {
528113b8070SValentin Clement return op->emitError("unsupported OpenACC operation: ")
529113b8070SValentin Clement << op->getName();
530113b8070SValentin Clement });
531113b8070SValentin Clement }
532113b8070SValentin Clement
registerOpenACCDialectTranslation(DialectRegistry & registry)533113b8070SValentin Clement void mlir::registerOpenACCDialectTranslation(DialectRegistry ®istry) {
534113b8070SValentin Clement registry.insert<acc::OpenACCDialect>();
53577eee579SRiver Riddle registry.addExtension(+[](MLIRContext *ctx, acc::OpenACCDialect *dialect) {
53677eee579SRiver Riddle dialect->addInterfaces<OpenACCDialectLLVMIRTranslationInterface>();
53777eee579SRiver Riddle });
538113b8070SValentin Clement }
539113b8070SValentin Clement
registerOpenACCDialectTranslation(MLIRContext & context)540113b8070SValentin Clement void mlir::registerOpenACCDialectTranslation(MLIRContext &context) {
541113b8070SValentin Clement DialectRegistry registry;
542113b8070SValentin Clement registerOpenACCDialectTranslation(registry);
543113b8070SValentin Clement context.appendDialectRegistry(registry);
544113b8070SValentin Clement }
545