xref: /llvm-project/mlir/lib/Target/LLVMIR/Dialect/OpenACC/OpenACCToLLVMIRTranslation.cpp (revision b00e0c167186d69e1e6bceda57c09b272bd6acfc)
1 //===- OpenACCToLLVMIRTranslation.cpp -------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements a translation between the MLIR OpenACC dialect and LLVM
10 // IR.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "mlir/Target/LLVMIR/Dialect/OpenACC/OpenACCToLLVMIRTranslation.h"
15 #include "mlir/Analysis/TopologicalSortUtils.h"
16 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
17 #include "mlir/Dialect/OpenACC/OpenACC.h"
18 #include "mlir/IR/BuiltinOps.h"
19 #include "mlir/IR/Operation.h"
20 #include "mlir/Support/LLVM.h"
21 #include "mlir/Target/LLVMIR/Dialect/OpenMPCommon.h"
22 #include "mlir/Target/LLVMIR/ModuleTranslation.h"
23 
24 #include "llvm/ADT/TypeSwitch.h"
25 #include "llvm/Frontend/OpenMP/OMPConstants.h"
26 #include "llvm/Support/FormatVariadic.h"
27 
28 using namespace mlir;
29 
30 using OpenACCIRBuilder = llvm::OpenMPIRBuilder;
31 
32 //===----------------------------------------------------------------------===//
33 // Utility functions
34 //===----------------------------------------------------------------------===//
35 
/// Map-type bit values, taken from openmp/libomptarget/include/omptarget.h and
/// assigned to the OpenACC data clauses they implement.
static constexpr uint64_t kCreateFlag{0x000};       // `alloc` mapping
static constexpr uint64_t kDeviceCopyinFlag{0x001}; // `to` mapping
static constexpr uint64_t kHostCopyoutFlag{0x002};  // `from` mapping
static constexpr uint64_t kDeleteFlag{0x008};       // `delete` mapping
static constexpr uint64_t kPresentFlag{0x1000};
/// Runtime extension used to implement the second OpenACC reference counter.
static constexpr uint64_t kHoldFlag{0x2000};

/// Device id passed to every mapper call emitted by this translation.
static constexpr int64_t kDefaultDevice{-1};
48 
49 /// Create the location struct from the operation location information.
createSourceLocationInfo(OpenACCIRBuilder & builder,Operation * op)50 static llvm::Value *createSourceLocationInfo(OpenACCIRBuilder &builder,
51                                              Operation *op) {
52   auto loc = op->getLoc();
53   auto funcOp = op->getParentOfType<LLVM::LLVMFuncOp>();
54   StringRef funcName = funcOp ? funcOp.getName() : "unknown";
55   uint32_t strLen;
56   llvm::Constant *locStr = mlir::LLVM::createSourceLocStrFromLocation(
57       loc, builder, funcName, strLen);
58   return builder.getOrCreateIdent(locStr, strLen);
59 }
60 
61 /// Return the runtime function used to lower the given operation.
getAssociatedFunction(OpenACCIRBuilder & builder,Operation * op)62 static llvm::Function *getAssociatedFunction(OpenACCIRBuilder &builder,
63                                              Operation *op) {
64   return llvm::TypeSwitch<Operation *, llvm::Function *>(op)
65       .Case([&](acc::EnterDataOp) {
66         return builder.getOrCreateRuntimeFunctionPtr(
67             llvm::omp::OMPRTL___tgt_target_data_begin_mapper);
68       })
69       .Case([&](acc::ExitDataOp) {
70         return builder.getOrCreateRuntimeFunctionPtr(
71             llvm::omp::OMPRTL___tgt_target_data_end_mapper);
72       })
73       .Case([&](acc::UpdateOp) {
74         return builder.getOrCreateRuntimeFunctionPtr(
75             llvm::omp::OMPRTL___tgt_target_data_update_mapper);
76       });
77   llvm_unreachable("Unknown OpenACC operation");
78 }
79 
80 /// Extract pointer, size and mapping information from operands
81 /// to populate the future functions arguments.
82 static LogicalResult
processOperands(llvm::IRBuilderBase & builder,LLVM::ModuleTranslation & moduleTranslation,Operation * op,ValueRange operands,unsigned totalNbOperand,uint64_t operandFlag,SmallVector<uint64_t> & flags,SmallVectorImpl<llvm::Constant * > & names,unsigned & index,struct OpenACCIRBuilder::MapperAllocas & mapperAllocas)83 processOperands(llvm::IRBuilderBase &builder,
84                 LLVM::ModuleTranslation &moduleTranslation, Operation *op,
85                 ValueRange operands, unsigned totalNbOperand,
86                 uint64_t operandFlag, SmallVector<uint64_t> &flags,
87                 SmallVectorImpl<llvm::Constant *> &names, unsigned &index,
88                 struct OpenACCIRBuilder::MapperAllocas &mapperAllocas) {
89   OpenACCIRBuilder *accBuilder = moduleTranslation.getOpenMPBuilder();
90   llvm::LLVMContext &ctx = builder.getContext();
91   auto *i8PtrTy = llvm::PointerType::getUnqual(ctx);
92   auto *arrI8PtrTy = llvm::ArrayType::get(i8PtrTy, totalNbOperand);
93   auto *i64Ty = llvm::Type::getInt64Ty(ctx);
94   auto *arrI64Ty = llvm::ArrayType::get(i64Ty, totalNbOperand);
95 
96   for (Value data : operands) {
97     llvm::Value *dataValue = moduleTranslation.lookupValue(data);
98 
99     llvm::Value *dataPtrBase;
100     llvm::Value *dataPtr;
101     llvm::Value *dataSize;
102 
103     if (isa<LLVM::LLVMPointerType>(data.getType())) {
104       dataPtrBase = dataValue;
105       dataPtr = dataValue;
106       dataSize = accBuilder->getSizeInBytes(dataValue);
107     } else {
108       return op->emitOpError()
109              << "Data operand must be legalized before translation."
110              << "Unsupported type: " << data.getType();
111     }
112 
113     // Store base pointer extracted from operand into the i-th position of
114     // argBase.
115     llvm::Value *ptrBaseGEP = builder.CreateInBoundsGEP(
116         arrI8PtrTy, mapperAllocas.ArgsBase,
117         {builder.getInt32(0), builder.getInt32(index)});
118     builder.CreateStore(dataPtrBase, ptrBaseGEP);
119 
120     // Store pointer extracted from operand into the i-th position of args.
121     llvm::Value *ptrGEP = builder.CreateInBoundsGEP(
122         arrI8PtrTy, mapperAllocas.Args,
123         {builder.getInt32(0), builder.getInt32(index)});
124     builder.CreateStore(dataPtr, ptrGEP);
125 
126     // Store size extracted from operand into the i-th position of argSizes.
127     llvm::Value *sizeGEP = builder.CreateInBoundsGEP(
128         arrI64Ty, mapperAllocas.ArgSizes,
129         {builder.getInt32(0), builder.getInt32(index)});
130     builder.CreateStore(dataSize, sizeGEP);
131 
132     flags.push_back(operandFlag);
133     llvm::Constant *mapName =
134         mlir::LLVM::createMappingInformation(data.getLoc(), *accBuilder);
135     names.push_back(mapName);
136     ++index;
137   }
138   return success();
139 }
140 
141 /// Process data operands from acc::EnterDataOp
142 static LogicalResult
processDataOperands(llvm::IRBuilderBase & builder,LLVM::ModuleTranslation & moduleTranslation,acc::EnterDataOp op,SmallVector<uint64_t> & flags,SmallVectorImpl<llvm::Constant * > & names,struct OpenACCIRBuilder::MapperAllocas & mapperAllocas)143 processDataOperands(llvm::IRBuilderBase &builder,
144                     LLVM::ModuleTranslation &moduleTranslation,
145                     acc::EnterDataOp op, SmallVector<uint64_t> &flags,
146                     SmallVectorImpl<llvm::Constant *> &names,
147                     struct OpenACCIRBuilder::MapperAllocas &mapperAllocas) {
148   // TODO add `create_zero` and `attach` operands
149 
150   unsigned index = 0;
151 
152   // Create operands are handled as `alloc` call.
153   // Copyin operands are handled as `to` call.
154   llvm::SmallVector<mlir::Value> create, copyin;
155   for (mlir::Value dataOp : op.getDataClauseOperands()) {
156     if (auto createOp =
157             mlir::dyn_cast_or_null<acc::CreateOp>(dataOp.getDefiningOp())) {
158       create.push_back(createOp.getVarPtr());
159     } else if (auto copyinOp = mlir::dyn_cast_or_null<acc::CopyinOp>(
160                    dataOp.getDefiningOp())) {
161       copyin.push_back(copyinOp.getVarPtr());
162     }
163   }
164 
165   auto nbTotalOperands = create.size() + copyin.size();
166 
167   // Create operands are handled as `alloc` call.
168   if (failed(processOperands(builder, moduleTranslation, op, create,
169                              nbTotalOperands, kCreateFlag, flags, names, index,
170                              mapperAllocas)))
171     return failure();
172 
173   // Copyin operands are handled as `to` call.
174   if (failed(processOperands(builder, moduleTranslation, op, copyin,
175                              nbTotalOperands, kDeviceCopyinFlag, flags, names,
176                              index, mapperAllocas)))
177     return failure();
178 
179   return success();
180 }
181 
182 /// Process data operands from acc::ExitDataOp
183 static LogicalResult
processDataOperands(llvm::IRBuilderBase & builder,LLVM::ModuleTranslation & moduleTranslation,acc::ExitDataOp op,SmallVector<uint64_t> & flags,SmallVectorImpl<llvm::Constant * > & names,struct OpenACCIRBuilder::MapperAllocas & mapperAllocas)184 processDataOperands(llvm::IRBuilderBase &builder,
185                     LLVM::ModuleTranslation &moduleTranslation,
186                     acc::ExitDataOp op, SmallVector<uint64_t> &flags,
187                     SmallVectorImpl<llvm::Constant *> &names,
188                     struct OpenACCIRBuilder::MapperAllocas &mapperAllocas) {
189   // TODO add `detach` operands
190 
191   unsigned index = 0;
192 
193   llvm::SmallVector<mlir::Value> deleteOperands, copyoutOperands;
194   for (mlir::Value dataOp : op.getDataClauseOperands()) {
195     if (auto devicePtrOp = mlir::dyn_cast_or_null<acc::GetDevicePtrOp>(
196             dataOp.getDefiningOp())) {
197       for (auto &u : devicePtrOp.getAccPtr().getUses()) {
198         if (mlir::dyn_cast_or_null<acc::DeleteOp>(u.getOwner()))
199           deleteOperands.push_back(devicePtrOp.getVarPtr());
200         else if (mlir::dyn_cast_or_null<acc::CopyoutOp>(u.getOwner()))
201           copyoutOperands.push_back(devicePtrOp.getVarPtr());
202       }
203     }
204   }
205 
206   auto nbTotalOperands = deleteOperands.size() + copyoutOperands.size();
207 
208   // Delete operands are handled as `delete` call.
209   if (failed(processOperands(builder, moduleTranslation, op, deleteOperands,
210                              nbTotalOperands, kDeleteFlag, flags, names, index,
211                              mapperAllocas)))
212     return failure();
213 
214   // Copyout operands are handled as `from` call.
215   if (failed(processOperands(builder, moduleTranslation, op, copyoutOperands,
216                              nbTotalOperands, kHostCopyoutFlag, flags, names,
217                              index, mapperAllocas)))
218     return failure();
219 
220   return success();
221 }
222 
223 /// Process data operands from acc::UpdateOp
224 static LogicalResult
processDataOperands(llvm::IRBuilderBase & builder,LLVM::ModuleTranslation & moduleTranslation,acc::UpdateOp op,SmallVector<uint64_t> & flags,SmallVectorImpl<llvm::Constant * > & names,struct OpenACCIRBuilder::MapperAllocas & mapperAllocas)225 processDataOperands(llvm::IRBuilderBase &builder,
226                     LLVM::ModuleTranslation &moduleTranslation,
227                     acc::UpdateOp op, SmallVector<uint64_t> &flags,
228                     SmallVectorImpl<llvm::Constant *> &names,
229                     struct OpenACCIRBuilder::MapperAllocas &mapperAllocas) {
230   unsigned index = 0;
231 
232   // Host operands are handled as `from` call.
233   // Device operands are handled as `to` call.
234   llvm::SmallVector<mlir::Value> from, to;
235   for (mlir::Value dataOp : op.getDataClauseOperands()) {
236     if (auto getDevicePtrOp = mlir::dyn_cast_or_null<acc::GetDevicePtrOp>(
237             dataOp.getDefiningOp())) {
238       from.push_back(getDevicePtrOp.getVarPtr());
239     } else if (auto updateDeviceOp =
240                    mlir::dyn_cast_or_null<acc::UpdateDeviceOp>(
241                        dataOp.getDefiningOp())) {
242       to.push_back(updateDeviceOp.getVarPtr());
243     }
244   }
245 
246   if (failed(processOperands(builder, moduleTranslation, op, from, from.size(),
247                              kHostCopyoutFlag, flags, names, index,
248                              mapperAllocas)))
249     return failure();
250 
251   if (failed(processOperands(builder, moduleTranslation, op, to, to.size(),
252                              kDeviceCopyinFlag, flags, names, index,
253                              mapperAllocas)))
254     return failure();
255   return success();
256 }
257 
258 //===----------------------------------------------------------------------===//
259 // Conversion functions
260 //===----------------------------------------------------------------------===//
261 
262 /// Converts an OpenACC data operation into LLVM IR.
convertDataOp(acc::DataOp & op,llvm::IRBuilderBase & builder,LLVM::ModuleTranslation & moduleTranslation)263 static LogicalResult convertDataOp(acc::DataOp &op,
264                                    llvm::IRBuilderBase &builder,
265                                    LLVM::ModuleTranslation &moduleTranslation) {
266   llvm::LLVMContext &ctx = builder.getContext();
267   auto enclosingFuncOp = op.getOperation()->getParentOfType<LLVM::LLVMFuncOp>();
268   llvm::Function *enclosingFunction =
269       moduleTranslation.lookupFunction(enclosingFuncOp.getName());
270 
271   OpenACCIRBuilder *accBuilder = moduleTranslation.getOpenMPBuilder();
272 
273   llvm::Value *srcLocInfo = createSourceLocationInfo(*accBuilder, op);
274 
275   llvm::Function *beginMapperFunc = accBuilder->getOrCreateRuntimeFunctionPtr(
276       llvm::omp::OMPRTL___tgt_target_data_begin_mapper);
277 
278   llvm::Function *endMapperFunc = accBuilder->getOrCreateRuntimeFunctionPtr(
279       llvm::omp::OMPRTL___tgt_target_data_end_mapper);
280 
281   // Number of arguments in the data operation.
282   unsigned totalNbOperand = op.getNumDataOperands();
283 
284   struct OpenACCIRBuilder::MapperAllocas mapperAllocas;
285   OpenACCIRBuilder::InsertPointTy allocaIP(
286       &enclosingFunction->getEntryBlock(),
287       enclosingFunction->getEntryBlock().getFirstInsertionPt());
288   accBuilder->createMapperAllocas(builder.saveIP(), allocaIP, totalNbOperand,
289                                   mapperAllocas);
290 
291   SmallVector<uint64_t> flags;
292   SmallVector<llvm::Constant *> names;
293   unsigned index = 0;
294 
295   // TODO handle no_create, deviceptr and attach operands.
296 
297   llvm::SmallVector<mlir::Value> copyin, copyout, create, present,
298       deleteOperands;
299   for (mlir::Value dataOp : op.getDataClauseOperands()) {
300     if (auto devicePtrOp = mlir::dyn_cast_or_null<acc::GetDevicePtrOp>(
301             dataOp.getDefiningOp())) {
302       for (auto &u : devicePtrOp.getAccPtr().getUses()) {
303         if (mlir::dyn_cast_or_null<acc::DeleteOp>(u.getOwner())) {
304           deleteOperands.push_back(devicePtrOp.getVarPtr());
305         } else if (mlir::dyn_cast_or_null<acc::CopyoutOp>(u.getOwner())) {
306           // TODO copyout zero currenlty handled as copyout. Update when
307           // extension available.
308           copyout.push_back(devicePtrOp.getVarPtr());
309         }
310       }
311     } else if (auto copyinOp = mlir::dyn_cast_or_null<acc::CopyinOp>(
312                    dataOp.getDefiningOp())) {
313       // TODO copyin readonly currenlty handled as copyin. Update when extension
314       // available.
315       copyin.push_back(copyinOp.getVarPtr());
316     } else if (auto createOp = mlir::dyn_cast_or_null<acc::CreateOp>(
317                    dataOp.getDefiningOp())) {
318       // TODO create zero currenlty handled as create. Update when extension
319       // available.
320       create.push_back(createOp.getVarPtr());
321     } else if (auto presentOp = mlir::dyn_cast_or_null<acc::PresentOp>(
322                    dataOp.getDefiningOp())) {
323       present.push_back(createOp.getVarPtr());
324     }
325   }
326 
327   auto nbTotalOperands = copyin.size() + copyout.size() + create.size() +
328                          present.size() + deleteOperands.size();
329 
330   // Copyin operands are handled as `to` call.
331   if (failed(processOperands(builder, moduleTranslation, op, copyin,
332                              nbTotalOperands, kDeviceCopyinFlag | kHoldFlag,
333                              flags, names, index, mapperAllocas)))
334     return failure();
335 
336   // Delete operands are handled as `delete` call.
337   if (failed(processOperands(builder, moduleTranslation, op, deleteOperands,
338                              nbTotalOperands, kDeleteFlag, flags, names, index,
339                              mapperAllocas)))
340     return failure();
341 
342   // Copyout operands are handled as `from` call.
343   if (failed(processOperands(builder, moduleTranslation, op, copyout,
344                              nbTotalOperands, kHostCopyoutFlag | kHoldFlag,
345                              flags, names, index, mapperAllocas)))
346     return failure();
347 
348   // Create operands are handled as `alloc` call.
349   if (failed(processOperands(builder, moduleTranslation, op, create,
350                              nbTotalOperands, kCreateFlag | kHoldFlag, flags,
351                              names, index, mapperAllocas)))
352     return failure();
353 
354   if (failed(processOperands(builder, moduleTranslation, op, present,
355                              nbTotalOperands, kPresentFlag | kHoldFlag, flags,
356                              names, index, mapperAllocas)))
357     return failure();
358 
359   llvm::GlobalVariable *maptypes =
360       accBuilder->createOffloadMaptypes(flags, ".offload_maptypes");
361   llvm::Value *maptypesArg = builder.CreateConstInBoundsGEP2_32(
362       llvm::ArrayType::get(llvm::Type::getInt64Ty(ctx), totalNbOperand),
363       maptypes, /*Idx0=*/0, /*Idx1=*/0);
364 
365   llvm::GlobalVariable *mapnames =
366       accBuilder->createOffloadMapnames(names, ".offload_mapnames");
367   llvm::Value *mapnamesArg = builder.CreateConstInBoundsGEP2_32(
368       llvm::ArrayType::get(llvm::PointerType::getUnqual(ctx), totalNbOperand),
369       mapnames, /*Idx0=*/0, /*Idx1=*/0);
370 
371   // Create call to start the data region.
372   accBuilder->emitMapperCall(builder.saveIP(), beginMapperFunc, srcLocInfo,
373                              maptypesArg, mapnamesArg, mapperAllocas,
374                              kDefaultDevice, totalNbOperand);
375 
376   // Convert the region.
377   llvm::BasicBlock *entryBlock = nullptr;
378 
379   for (Block &bb : op.getRegion()) {
380     llvm::BasicBlock *llvmBB = llvm::BasicBlock::Create(
381         ctx, "acc.data", builder.GetInsertBlock()->getParent());
382     if (entryBlock == nullptr)
383       entryBlock = llvmBB;
384     moduleTranslation.mapBlock(&bb, llvmBB);
385   }
386 
387   auto afterDataRegion = builder.saveIP();
388 
389   llvm::BranchInst *sourceTerminator = builder.CreateBr(entryBlock);
390 
391   builder.restoreIP(afterDataRegion);
392   llvm::BasicBlock *endDataBlock = llvm::BasicBlock::Create(
393       ctx, "acc.end_data", builder.GetInsertBlock()->getParent());
394 
395   SetVector<Block *> blocks = getBlocksSortedByDominance(op.getRegion());
396   for (Block *bb : blocks) {
397     llvm::BasicBlock *llvmBB = moduleTranslation.lookupBlock(bb);
398     if (bb->isEntryBlock()) {
399       assert(sourceTerminator->getNumSuccessors() == 1 &&
400              "provided entry block has multiple successors");
401       sourceTerminator->setSuccessor(0, llvmBB);
402     }
403 
404     if (failed(
405             moduleTranslation.convertBlock(*bb, bb->isEntryBlock(), builder))) {
406       return failure();
407     }
408 
409     if (isa<acc::TerminatorOp, acc::YieldOp>(bb->getTerminator()))
410       builder.CreateBr(endDataBlock);
411   }
412 
413   // Create call to end the data region.
414   builder.SetInsertPoint(endDataBlock);
415   accBuilder->emitMapperCall(builder.saveIP(), endMapperFunc, srcLocInfo,
416                              maptypesArg, mapnamesArg, mapperAllocas,
417                              kDefaultDevice, totalNbOperand);
418 
419   return success();
420 }
421 
422 /// Converts an OpenACC standalone data operation into LLVM IR.
423 template <typename OpTy>
424 static LogicalResult
convertStandaloneDataOp(OpTy & op,llvm::IRBuilderBase & builder,LLVM::ModuleTranslation & moduleTranslation)425 convertStandaloneDataOp(OpTy &op, llvm::IRBuilderBase &builder,
426                         LLVM::ModuleTranslation &moduleTranslation) {
427   auto enclosingFuncOp =
428       op.getOperation()->template getParentOfType<LLVM::LLVMFuncOp>();
429   llvm::Function *enclosingFunction =
430       moduleTranslation.lookupFunction(enclosingFuncOp.getName());
431 
432   OpenACCIRBuilder *accBuilder = moduleTranslation.getOpenMPBuilder();
433 
434   auto *srcLocInfo = createSourceLocationInfo(*accBuilder, op);
435   auto *mapperFunc = getAssociatedFunction(*accBuilder, op);
436 
437   // Number of arguments in the enter_data operation.
438   unsigned totalNbOperand = op.getNumDataOperands();
439 
440   llvm::LLVMContext &ctx = builder.getContext();
441 
442   struct OpenACCIRBuilder::MapperAllocas mapperAllocas;
443   OpenACCIRBuilder::InsertPointTy allocaIP(
444       &enclosingFunction->getEntryBlock(),
445       enclosingFunction->getEntryBlock().getFirstInsertionPt());
446   accBuilder->createMapperAllocas(builder.saveIP(), allocaIP, totalNbOperand,
447                                   mapperAllocas);
448 
449   SmallVector<uint64_t> flags;
450   SmallVector<llvm::Constant *> names;
451 
452   if (failed(processDataOperands(builder, moduleTranslation, op, flags, names,
453                                  mapperAllocas)))
454     return failure();
455 
456   llvm::GlobalVariable *maptypes =
457       accBuilder->createOffloadMaptypes(flags, ".offload_maptypes");
458   llvm::Value *maptypesArg = builder.CreateConstInBoundsGEP2_32(
459       llvm::ArrayType::get(llvm::Type::getInt64Ty(ctx), totalNbOperand),
460       maptypes, /*Idx0=*/0, /*Idx1=*/0);
461 
462   llvm::GlobalVariable *mapnames =
463       accBuilder->createOffloadMapnames(names, ".offload_mapnames");
464   llvm::Value *mapnamesArg = builder.CreateConstInBoundsGEP2_32(
465       llvm::ArrayType::get(llvm::PointerType::getUnqual(ctx), totalNbOperand),
466       mapnames, /*Idx0=*/0, /*Idx1=*/0);
467 
468   accBuilder->emitMapperCall(builder.saveIP(), mapperFunc, srcLocInfo,
469                              maptypesArg, mapnamesArg, mapperAllocas,
470                              kDefaultDevice, totalNbOperand);
471 
472   return success();
473 }
474 
475 namespace {
476 
477 /// Implementation of the dialect interface that converts operations belonging
478 /// to the OpenACC dialect to LLVM IR.
479 class OpenACCDialectLLVMIRTranslationInterface
480     : public LLVMTranslationDialectInterface {
481 public:
482   using LLVMTranslationDialectInterface::LLVMTranslationDialectInterface;
483 
484   /// Translates the given operation to LLVM IR using the provided IR builder
485   /// and saving the state in `moduleTranslation`.
486   LogicalResult
487   convertOperation(Operation *op, llvm::IRBuilderBase &builder,
488                    LLVM::ModuleTranslation &moduleTranslation) const final;
489 };
490 
491 } // namespace
492 
493 /// Given an OpenACC MLIR operation, create the corresponding LLVM IR
494 /// (including OpenACC runtime calls).
convertOperation(Operation * op,llvm::IRBuilderBase & builder,LLVM::ModuleTranslation & moduleTranslation) const495 LogicalResult OpenACCDialectLLVMIRTranslationInterface::convertOperation(
496     Operation *op, llvm::IRBuilderBase &builder,
497     LLVM::ModuleTranslation &moduleTranslation) const {
498 
499   return llvm::TypeSwitch<Operation *, LogicalResult>(op)
500       .Case([&](acc::DataOp dataOp) {
501         return convertDataOp(dataOp, builder, moduleTranslation);
502       })
503       .Case([&](acc::EnterDataOp enterDataOp) {
504         return convertStandaloneDataOp<acc::EnterDataOp>(enterDataOp, builder,
505                                                          moduleTranslation);
506       })
507       .Case([&](acc::ExitDataOp exitDataOp) {
508         return convertStandaloneDataOp<acc::ExitDataOp>(exitDataOp, builder,
509                                                         moduleTranslation);
510       })
511       .Case([&](acc::UpdateOp updateOp) {
512         return convertStandaloneDataOp<acc::UpdateOp>(updateOp, builder,
513                                                       moduleTranslation);
514       })
515       .Case<acc::TerminatorOp, acc::YieldOp>([](auto op) {
516         // `yield` and `terminator` can be just omitted. The block structure was
517         // created in the function that handles their parent operation.
518         assert(op->getNumOperands() == 0 &&
519                "unexpected OpenACC terminator with operands");
520         return success();
521       })
522       .Case<acc::CreateOp, acc::CopyinOp, acc::CopyoutOp, acc::DeleteOp,
523             acc::UpdateDeviceOp, acc::GetDevicePtrOp>([](auto op) {
524         // NOP
525         return success();
526       })
527       .Default([&](Operation *op) {
528         return op->emitError("unsupported OpenACC operation: ")
529                << op->getName();
530       });
531 }
532 
registerOpenACCDialectTranslation(DialectRegistry & registry)533 void mlir::registerOpenACCDialectTranslation(DialectRegistry &registry) {
534   registry.insert<acc::OpenACCDialect>();
535   registry.addExtension(+[](MLIRContext *ctx, acc::OpenACCDialect *dialect) {
536     dialect->addInterfaces<OpenACCDialectLLVMIRTranslationInterface>();
537   });
538 }
539 
registerOpenACCDialectTranslation(MLIRContext & context)540 void mlir::registerOpenACCDialectTranslation(MLIRContext &context) {
541   DialectRegistry registry;
542   registerOpenACCDialectTranslation(registry);
543   context.appendDialectRegistry(registry);
544 }
545