xref: /llvm-project/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp (revision e0054e984cac39322afa32a6e68fc794f0081f49)
1 //===- OpenMPToLLVMIRTranslation.cpp - Translate OpenMP dialect to LLVM IR-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements a translation between the MLIR OpenMP dialect and LLVM
10 // IR.
11 //
12 //===----------------------------------------------------------------------===//
13 #include "mlir/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.h"
14 #include "mlir/Analysis/TopologicalSortUtils.h"
15 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
16 #include "mlir/Dialect/OpenMP/OpenMPDialect.h"
17 #include "mlir/Dialect/OpenMP/OpenMPInterfaces.h"
18 #include "mlir/IR/IRMapping.h"
19 #include "mlir/IR/Operation.h"
20 #include "mlir/Support/LLVM.h"
21 #include "mlir/Target/LLVMIR/Dialect/OpenMPCommon.h"
22 #include "mlir/Target/LLVMIR/ModuleTranslation.h"
23 #include "mlir/Transforms/RegionUtils.h"
24 
25 #include "llvm/ADT/ArrayRef.h"
26 #include "llvm/ADT/SetVector.h"
27 #include "llvm/ADT/TypeSwitch.h"
28 #include "llvm/Frontend/OpenMP/OMPConstants.h"
29 #include "llvm/Frontend/OpenMP/OMPIRBuilder.h"
30 #include "llvm/IR/DebugInfoMetadata.h"
31 #include "llvm/IR/IRBuilder.h"
32 #include "llvm/IR/ReplaceConstant.h"
33 #include "llvm/Support/FileSystem.h"
34 #include "llvm/TargetParser/Triple.h"
35 #include "llvm/Transforms/Utils/ModuleUtils.h"
36 
37 #include <any>
38 #include <cstdint>
39 #include <iterator>
40 #include <numeric>
41 #include <optional>
42 #include <utility>
43 
44 using namespace mlir;
45 
46 namespace {
47 static llvm::omp::ScheduleKind
48 convertToScheduleKind(std::optional<omp::ClauseScheduleKind> schedKind) {
49   if (!schedKind.has_value())
50     return llvm::omp::OMP_SCHEDULE_Default;
51   switch (schedKind.value()) {
52   case omp::ClauseScheduleKind::Static:
53     return llvm::omp::OMP_SCHEDULE_Static;
54   case omp::ClauseScheduleKind::Dynamic:
55     return llvm::omp::OMP_SCHEDULE_Dynamic;
56   case omp::ClauseScheduleKind::Guided:
57     return llvm::omp::OMP_SCHEDULE_Guided;
58   case omp::ClauseScheduleKind::Auto:
59     return llvm::omp::OMP_SCHEDULE_Auto;
60   case omp::ClauseScheduleKind::Runtime:
61     return llvm::omp::OMP_SCHEDULE_Runtime;
62   }
63   llvm_unreachable("unhandled schedule clause argument");
64 }
65 
/// ModuleTranslation stack frame for OpenMP operations. This keeps track of the
/// insertion points for allocas. Frames are pushed by operations that define a
/// dedicated alloca region/point and consumed by findAllocaInsertPoint().
class OpenMPAllocaStackFrame
    : public LLVM::ModuleTranslation::StackFrameBase<OpenMPAllocaStackFrame> {
public:
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(OpenMPAllocaStackFrame)

  explicit OpenMPAllocaStackFrame(llvm::OpenMPIRBuilder::InsertPointTy allocaIP)
      : allocaInsertPoint(allocaIP) {}
  /// Where allocas created while this frame is active must be inserted.
  llvm::OpenMPIRBuilder::InsertPointTy allocaInsertPoint;
};
77 
/// ModuleTranslation stack frame containing the partial mapping between MLIR
/// values and their LLVM IR equivalents. Used to scope value mappings that
/// only hold within a particular OpenMP region being translated.
class OpenMPVarMappingStackFrame
    : public LLVM::ModuleTranslation::StackFrameBase<
          OpenMPVarMappingStackFrame> {
public:
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(OpenMPVarMappingStackFrame)

  explicit OpenMPVarMappingStackFrame(
      const DenseMap<Value, llvm::Value *> &mapping)
      : mapping(mapping) {}

  /// MLIR value -> translated LLVM IR value, valid while this frame is live.
  DenseMap<Value, llvm::Value *> mapping;
};
92 
/// Custom error class to signal translation errors that don't need reporting,
/// since encountering them will have already triggered relevant error messages.
///
/// Its purpose is to serve as the glue between MLIR failures represented as
/// \see LogicalResult instances and \see llvm::Error instances used to
/// propagate errors through the \see llvm::OpenMPIRBuilder. Generally, when an
/// error of the first type is raised, a message is emitted directly (the \see
/// LogicalResult itself does not hold any information). If we need to forward
/// this error condition as an \see llvm::Error while avoiding triggering some
/// redundant error reporting later on, we need a custom \see llvm::ErrorInfo
/// class to just signal this situation has happened.
///
/// For example, this class should be used to trigger errors from within
/// callbacks passed to the \see OpenMPIRBuilder when they were triggered by the
/// translation of their own regions. This unclutters the error log from
/// redundant messages.
class PreviouslyReportedError
    : public llvm::ErrorInfo<PreviouslyReportedError> {
public:
  // Intentionally empty: the diagnostic was already emitted at the point of
  // failure, so surfacing this error must not print anything.
  void log(raw_ostream &) const override {
    // Do not log anything.
  }

  // This error type is never converted to a std::error_code; reaching this
  // indicates a misuse of the class.
  std::error_code convertToErrorCode() const override {
    llvm_unreachable(
        "PreviouslyReportedError doesn't support ECError conversion");
  }

  // Used by ErrorInfo::classID.
  static char ID;
};

// Out-of-line definition of the static type-id anchor used by llvm::ErrorInfo.
char PreviouslyReportedError::ID = 0;
126 
127 } // namespace
128 
129 /// Looks up from the operation from and returns the PrivateClauseOp with
130 /// name symbolName
131 static omp::PrivateClauseOp findPrivatizer(Operation *from,
132                                            SymbolRefAttr symbolName) {
133   omp::PrivateClauseOp privatizer =
134       SymbolTable::lookupNearestSymbolFrom<omp::PrivateClauseOp>(from,
135                                                                  symbolName);
136   assert(privatizer && "privatizer not found in the symbol table");
137   return privatizer;
138 }
139 
/// Check whether translation to LLVM IR for the given operation is currently
/// supported. If not, descriptive diagnostics will be emitted to let users know
/// this is a not-yet-implemented feature.
///
/// \returns success if no unimplemented features are needed to translate the
///          given operation.
static LogicalResult checkImplementationStatus(Operation &op) {
  // Emits the standard "not yet implemented" diagnostic for a named clause.
  auto todo = [&op](StringRef clauseName) {
    return op.emitError() << "not yet implemented: Unhandled clause "
                          << clauseName << " in " << op.getName()
                          << " operation";
  };

  // Per-clause checkers, shared across operation types below. Each one sets
  // `result` to failure (emitting a diagnostic) when the corresponding
  // unsupported clause is present on `op`; otherwise `result` is untouched.
  auto checkAllocate = [&todo](auto op, LogicalResult &result) {
    if (!op.getAllocateVars().empty() || !op.getAllocatorVars().empty())
      result = todo("allocate");
  };
  auto checkBare = [&todo](auto op, LogicalResult &result) {
    if (op.getBare())
      result = todo("ompx_bare");
  };
  auto checkDepend = [&todo](auto op, LogicalResult &result) {
    if (!op.getDependVars().empty() || op.getDependKinds())
      result = todo("depend");
  };
  auto checkDevice = [&todo](auto op, LogicalResult &result) {
    if (op.getDevice())
      result = todo("device");
  };
  auto checkHasDeviceAddr = [&todo](auto op, LogicalResult &result) {
    if (!op.getHasDeviceAddrVars().empty())
      result = todo("has_device_addr");
  };
  // The hint clause is ignorable rather than a hard error: only warn.
  auto checkHint = [](auto op, LogicalResult &) {
    if (op.getHint())
      op.emitWarning("hint clause discarded");
  };
  auto checkHostEval = [](auto op, LogicalResult &result) {
    // Host evaluated clauses are supported, except for loop bounds.
    for (BlockArgument arg :
         cast<omp::BlockArgOpenMPOpInterface>(*op).getHostEvalBlockArgs())
      for (Operation *user : arg.getUsers())
        if (isa<omp::LoopNestOp>(user))
          result = op.emitError("not yet implemented: host evaluation of loop "
                                "bounds in omp.target operation");
  };
  auto checkInReduction = [&todo](auto op, LogicalResult &result) {
    if (!op.getInReductionVars().empty() || op.getInReductionByref() ||
        op.getInReductionSyms())
      result = todo("in_reduction");
  };
  auto checkIsDevicePtr = [&todo](auto op, LogicalResult &result) {
    if (!op.getIsDevicePtrVars().empty())
      result = todo("is_device_ptr");
  };
  auto checkLinear = [&todo](auto op, LogicalResult &result) {
    if (!op.getLinearVars().empty() || !op.getLinearStepVars().empty())
      result = todo("linear");
  };
  auto checkNontemporal = [&todo](auto op, LogicalResult &result) {
    if (!op.getNontemporalVars().empty())
      result = todo("nontemporal");
  };
  auto checkNowait = [&todo](auto op, LogicalResult &result) {
    if (op.getNowait())
      result = todo("nowait");
  };
  auto checkOrder = [&todo](auto op, LogicalResult &result) {
    if (op.getOrder() || op.getOrderMod())
      result = todo("order");
  };
  auto checkParLevelSimd = [&todo](auto op, LogicalResult &result) {
    if (op.getParLevelSimd())
      result = todo("parallelization-level");
  };
  auto checkPriority = [&todo](auto op, LogicalResult &result) {
    if (op.getPriority())
      result = todo("priority");
  };
  auto checkPrivate = [&todo](auto op, LogicalResult &result) {
    // omp.target supports privatization except for firstprivate, whereas the
    // remaining operations handled here do not support it at all; the
    // compile-time branch keeps the accessor usage valid for both cases.
    if constexpr (std::is_same_v<std::decay_t<decltype(op)>, omp::TargetOp>) {
      // Privatization clauses are supported, except on some situations, so we
      // need to check here whether any of these unsupported cases are being
      // translated.
      if (std::optional<ArrayAttr> privateSyms = op.getPrivateSyms()) {
        for (Attribute privatizerNameAttr : *privateSyms) {
          omp::PrivateClauseOp privatizer = findPrivatizer(
              op.getOperation(), cast<SymbolRefAttr>(privatizerNameAttr));

          if (privatizer.getDataSharingType() ==
              omp::DataSharingClauseType::FirstPrivate)
            result = todo("firstprivate");
        }
      }
    } else {
      if (!op.getPrivateVars().empty() || op.getPrivateSyms())
        result = todo("privatization");
    }
  };
  auto checkReduction = [&todo](auto op, LogicalResult &result) {
    // Reductions are unsupported only on teams/simd; other users of this
    // checker already handle plain reductions, so only the modifier is
    // rejected for them.
    if (isa<omp::TeamsOp>(op) || isa<omp::SimdOp>(op))
      if (!op.getReductionVars().empty() || op.getReductionByref() ||
          op.getReductionSyms())
        result = todo("reduction");
    if (op.getReductionMod() &&
        op.getReductionMod().value() != omp::ReductionModifier::defaultmod)
      result = todo("reduction with modifier");
  };
  auto checkTaskReduction = [&todo](auto op, LogicalResult &result) {
    if (!op.getTaskReductionVars().empty() || op.getTaskReductionByref() ||
        op.getTaskReductionSyms())
      result = todo("task_reduction");
  };
  auto checkUntied = [&todo](auto op, LogicalResult &result) {
    if (op.getUntied())
      result = todo("untied");
  };

  // Dispatch on the concrete operation type and run the checkers relevant to
  // it. Operations not listed are assumed fully supported.
  LogicalResult result = success();
  llvm::TypeSwitch<Operation &>(op)
      .Case([&](omp::OrderedRegionOp op) { checkParLevelSimd(op, result); })
      .Case([&](omp::SectionsOp op) {
        checkAllocate(op, result);
        checkPrivate(op, result);
        checkReduction(op, result);
      })
      .Case([&](omp::SingleOp op) {
        checkAllocate(op, result);
        checkPrivate(op, result);
      })
      .Case([&](omp::TeamsOp op) {
        checkAllocate(op, result);
        checkPrivate(op, result);
        checkReduction(op, result);
      })
      .Case([&](omp::TaskOp op) {
        checkAllocate(op, result);
        checkInReduction(op, result);
      })
      .Case([&](omp::TaskgroupOp op) {
        checkAllocate(op, result);
        checkTaskReduction(op, result);
      })
      .Case([&](omp::TaskwaitOp op) {
        checkDepend(op, result);
        checkNowait(op, result);
      })
      .Case([&](omp::TaskloopOp op) {
        // TODO: Add other clauses check
        checkUntied(op, result);
        checkPriority(op, result);
      })
      .Case([&](omp::WsloopOp op) {
        checkAllocate(op, result);
        checkLinear(op, result);
        checkOrder(op, result);
        checkReduction(op, result);
      })
      .Case([&](omp::ParallelOp op) {
        checkAllocate(op, result);
        checkReduction(op, result);
      })
      .Case([&](omp::SimdOp op) {
        checkLinear(op, result);
        checkNontemporal(op, result);
        checkReduction(op, result);
      })
      .Case<omp::AtomicReadOp, omp::AtomicWriteOp, omp::AtomicUpdateOp,
            omp::AtomicCaptureOp>([&](auto op) { checkHint(op, result); })
      .Case<omp::TargetEnterDataOp, omp::TargetExitDataOp, omp::TargetUpdateOp>(
          [&](auto op) { checkDepend(op, result); })
      .Case([&](omp::TargetOp op) {
        checkAllocate(op, result);
        checkBare(op, result);
        checkDevice(op, result);
        checkHasDeviceAddr(op, result);
        checkHostEval(op, result);
        checkInReduction(op, result);
        checkIsDevicePtr(op, result);
        checkPrivate(op, result);
      })
      .Default([](Operation &) {
        // Assume all clauses for an operation can be translated unless they are
        // checked above.
      });
  return result;
}
327 
328 static LogicalResult handleError(llvm::Error error, Operation &op) {
329   LogicalResult result = success();
330   if (error) {
331     llvm::handleAllErrors(
332         std::move(error),
333         [&](const PreviouslyReportedError &) { result = failure(); },
334         [&](const llvm::ErrorInfoBase &err) {
335           result = op.emitError(err.message());
336         });
337   }
338   return result;
339 }
340 
341 template <typename T>
342 static LogicalResult handleError(llvm::Expected<T> &result, Operation &op) {
343   if (!result)
344     return handleError(result.takeError(), op);
345 
346   return success();
347 }
348 
/// Find the insertion point for allocas given the current insertion point for
/// normal operations in the builder.
static llvm::OpenMPIRBuilder::InsertPointTy
findAllocaInsertPoint(llvm::IRBuilderBase &builder,
                      const LLVM::ModuleTranslation &moduleTranslation) {
  // If there is an alloca insertion point on stack, i.e. we are in a nested
  // operation and a specific point was provided by some surrounding operation,
  // use it.
  llvm::OpenMPIRBuilder::InsertPointTy allocaInsertPoint;
  WalkResult walkResult = moduleTranslation.stackWalk<OpenMPAllocaStackFrame>(
      [&](const OpenMPAllocaStackFrame &frame) {
        // Stop at the first (innermost) frame found.
        allocaInsertPoint = frame.allocaInsertPoint;
        return WalkResult::interrupt();
      });
  if (walkResult.wasInterrupted())
    return allocaInsertPoint;

  // Otherwise, insert to the entry block of the surrounding function.
  // If the current IRBuilder InsertPoint is the function's entry, it cannot
  // also be used for alloca insertion which would result in insertion order
  // confusion. Create a new BasicBlock for the Builder and use the entry block
  // for the allocs.
  // TODO: Create a dedicated alloca BasicBlock at function creation such that
  // we do not need to move the current InsertPoint here.
  if (builder.GetInsertBlock() ==
      &builder.GetInsertBlock()->getParent()->getEntryBlock()) {
    assert(builder.GetInsertPoint() == builder.GetInsertBlock()->end() &&
           "Assuming end of basic block");
    // Split control flow: regular code generation continues in a fresh
    // "entry" block, leaving the original function entry block to hold the
    // allocas exclusively.
    llvm::BasicBlock *entryBB = llvm::BasicBlock::Create(
        builder.getContext(), "entry", builder.GetInsertBlock()->getParent(),
        builder.GetInsertBlock()->getNextNode());
    builder.CreateBr(entryBB);
    builder.SetInsertPoint(entryBB);
  }

  llvm::BasicBlock &funcEntryBlock =
      builder.GetInsertBlock()->getParent()->getEntryBlock();
  return llvm::OpenMPIRBuilder::InsertPointTy(
      &funcEntryBlock, funcEntryBlock.getFirstInsertionPt());
}
389 
/// Converts the given region that appears within an OpenMP dialect operation to
/// LLVM IR, creating a branch from the `sourceBlock` to the entry block of the
/// region, and a branch from any block with an successor-less OpenMP terminator
/// to `continuationBlock`. Populates `continuationBlockPHIs` with the PHI nodes
/// of the continuation block if provided.
static llvm::Expected<llvm::BasicBlock *> convertOmpOpRegions(
    Region &region, StringRef blockName, llvm::IRBuilderBase &builder,
    LLVM::ModuleTranslation &moduleTranslation,
    SmallVectorImpl<llvm::PHINode *> *continuationBlockPHIs = nullptr) {
  llvm::BasicBlock *continuationBlock =
      splitBB(builder, true, "omp.region.cont");
  llvm::BasicBlock *sourceBlock = builder.GetInsertBlock();

  // Create one LLVM block per MLIR block up front so that branches between
  // them can be emitted before the blocks' contents are translated.
  llvm::LLVMContext &llvmContext = builder.getContext();
  for (Block &bb : region) {
    llvm::BasicBlock *llvmBB = llvm::BasicBlock::Create(
        llvmContext, blockName, builder.GetInsertBlock()->getParent(),
        builder.GetInsertBlock()->getNextNode());
    moduleTranslation.mapBlock(&bb, llvmBB);
  }

  llvm::Instruction *sourceTerminator = sourceBlock->getTerminator();

  // Terminators (namely YieldOp) may be forwarding values to the region that
  // need to be available in the continuation block. Collect the types of these
  // operands in preparation of creating PHI nodes. All yields must forward the
  // same number and types of values; this is checked by assertion against the
  // first yield encountered.
  SmallVector<llvm::Type *> continuationBlockPHITypes;
  bool operandsProcessed = false;
  unsigned numYields = 0;
  for (Block &bb : region.getBlocks()) {
    if (omp::YieldOp yield = dyn_cast<omp::YieldOp>(bb.getTerminator())) {
      if (!operandsProcessed) {
        for (unsigned i = 0, e = yield->getNumOperands(); i < e; ++i) {
          continuationBlockPHITypes.push_back(
              moduleTranslation.convertType(yield->getOperand(i).getType()));
        }
        operandsProcessed = true;
      } else {
        assert(continuationBlockPHITypes.size() == yield->getNumOperands() &&
               "mismatching number of values yielded from the region");
        for (unsigned i = 0, e = yield->getNumOperands(); i < e; ++i) {
          llvm::Type *operandType =
              moduleTranslation.convertType(yield->getOperand(i).getType());
          (void)operandType;
          assert(continuationBlockPHITypes[i] == operandType &&
                 "values of mismatching types yielded from the region");
        }
      }
      numYields++;
    }
  }

  // Insert PHI nodes in the continuation block for any values forwarded by the
  // terminators in this region.
  if (!continuationBlockPHITypes.empty())
    assert(
        continuationBlockPHIs &&
        "expected continuation block PHIs if converted regions yield values");
  if (continuationBlockPHIs) {
    llvm::IRBuilderBase::InsertPointGuard guard(builder);
    continuationBlockPHIs->reserve(continuationBlockPHITypes.size());
    builder.SetInsertPoint(continuationBlock, continuationBlock->begin());
    for (llvm::Type *ty : continuationBlockPHITypes)
      continuationBlockPHIs->push_back(builder.CreatePHI(ty, numYields));
  }

  // Convert blocks one by one in topological order to ensure
  // defs are converted before uses.
  SetVector<Block *> blocks = getBlocksSortedByDominance(region);
  for (Block *bb : blocks) {
    llvm::BasicBlock *llvmBB = moduleTranslation.lookupBlock(bb);
    // Retarget the branch of the entry block to the entry block of the
    // converted region (regions are single-entry).
    if (bb->isEntryBlock()) {
      assert(sourceTerminator->getNumSuccessors() == 1 &&
             "provided entry block has multiple successors");
      assert(sourceTerminator->getSuccessor(0) == continuationBlock &&
             "ContinuationBlock is not the successor of the entry block");
      sourceTerminator->setSuccessor(0, llvmBB);
    }

    llvm::IRBuilderBase::InsertPointGuard guard(builder);
    if (failed(
            moduleTranslation.convertBlock(*bb, bb->isEntryBlock(), builder)))
      // The block conversion already emitted a diagnostic; signal failure
      // without duplicating it.
      return llvm::make_error<PreviouslyReportedError>();

    // Special handling for `omp.yield` and `omp.terminator` (we may have more
    // than one): they return the control to the parent OpenMP dialect operation
    // so replace them with the branch to the continuation block. We handle this
    // here to avoid relying inter-function communication through the
    // ModuleTranslation class to set up the correct insertion point. This is
    // also consistent with MLIR's idiom of handling special region terminators
    // in the same code that handles the region-owning operation.
    Operation *terminator = bb->getTerminator();
    if (isa<omp::TerminatorOp, omp::YieldOp>(terminator)) {
      builder.CreateBr(continuationBlock);

      // NOTE(review): this dereferences continuationBlockPHIs without a null
      // check; it relies on callers passing it whenever the region yields
      // values, which is only asserted above — confirm for new call sites.
      for (unsigned i = 0, e = terminator->getNumOperands(); i < e; ++i)
        (*continuationBlockPHIs)[i]->addIncoming(
            moduleTranslation.lookupValue(terminator->getOperand(i)), llvmBB);
    }
  }
  // After all blocks have been traversed and values mapped, connect the PHI
  // nodes to the results of preceding blocks.
  LLVM::detail::connectPHINodes(region, moduleTranslation);

  // Remove the blocks and values defined in this region from the mapping since
  // they are not visible outside of this region. This allows the same region to
  // be converted several times, that is cloned, without clashes, and slightly
  // speeds up the lookups.
  moduleTranslation.forgetMapping(region);

  return continuationBlock;
}
504 
505 /// Convert ProcBindKind from MLIR-generated enum to LLVM enum.
506 static llvm::omp::ProcBindKind getProcBindKind(omp::ClauseProcBindKind kind) {
507   switch (kind) {
508   case omp::ClauseProcBindKind::Close:
509     return llvm::omp::ProcBindKind::OMP_PROC_BIND_close;
510   case omp::ClauseProcBindKind::Master:
511     return llvm::omp::ProcBindKind::OMP_PROC_BIND_master;
512   case omp::ClauseProcBindKind::Primary:
513     return llvm::omp::ProcBindKind::OMP_PROC_BIND_primary;
514   case omp::ClauseProcBindKind::Spread:
515     return llvm::omp::ProcBindKind::OMP_PROC_BIND_spread;
516   }
517   llvm_unreachable("Unknown ClauseProcBindKind kind");
518 }
519 
520 /// Helper function to map block arguments defined by ignored loop wrappers to
521 /// LLVM values and prevent any uses of those from triggering null pointer
522 /// dereferences.
523 ///
524 /// This must be called after block arguments of parent wrappers have already
525 /// been mapped to LLVM IR values.
526 static LogicalResult
527 convertIgnoredWrapper(omp::LoopWrapperInterface &opInst,
528                       LLVM::ModuleTranslation &moduleTranslation) {
529   // Map block arguments directly to the LLVM value associated to the
530   // corresponding operand. This is semantically equivalent to this wrapper not
531   // being present.
532   auto forwardArgs =
533       [&moduleTranslation](llvm::ArrayRef<BlockArgument> blockArgs,
534                            OperandRange operands) {
535         for (auto [arg, var] : llvm::zip_equal(blockArgs, operands))
536           moduleTranslation.mapValue(arg, moduleTranslation.lookupValue(var));
537       };
538 
539   return llvm::TypeSwitch<Operation *, LogicalResult>(opInst)
540       .Case([&](omp::SimdOp op) {
541         auto blockArgIface = cast<omp::BlockArgOpenMPOpInterface>(*op);
542         forwardArgs(blockArgIface.getPrivateBlockArgs(), op.getPrivateVars());
543         forwardArgs(blockArgIface.getReductionBlockArgs(),
544                     op.getReductionVars());
545         op.emitWarning() << "simd information on composite construct discarded";
546         return success();
547       })
548       .Default([&](Operation *op) {
549         return op->emitError() << "cannot ignore nested wrapper";
550       });
551 }
552 
553 /// Helper function to call \c convertIgnoredWrapper() for all wrappers of the
554 /// given \c loopOp nested inside of \c parentOp. This has the effect of mapping
555 /// entry block arguments defined by these operations to outside values.
556 ///
557 /// It must be called after block arguments of \c parentOp have already been
558 /// mapped themselves.
559 static LogicalResult
560 convertIgnoredWrappers(omp::LoopNestOp loopOp,
561                        omp::LoopWrapperInterface parentOp,
562                        LLVM::ModuleTranslation &moduleTranslation) {
563   SmallVector<omp::LoopWrapperInterface> wrappers;
564   loopOp.gatherWrappers(wrappers);
565 
566   // Process wrappers nested inside of `parentOp` from outermost to innermost.
567   for (auto it =
568            std::next(std::find(wrappers.rbegin(), wrappers.rend(), parentOp));
569        it != wrappers.rend(); ++it) {
570     if (failed(convertIgnoredWrapper(*it, moduleTranslation)))
571       return failure();
572   }
573 
574   return success();
575 }
576 
577 /// Converts an OpenMP 'masked' operation into LLVM IR using OpenMPIRBuilder.
578 static LogicalResult
579 convertOmpMasked(Operation &opInst, llvm::IRBuilderBase &builder,
580                  LLVM::ModuleTranslation &moduleTranslation) {
581   auto maskedOp = cast<omp::MaskedOp>(opInst);
582   using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
583 
584   if (failed(checkImplementationStatus(opInst)))
585     return failure();
586 
587   auto bodyGenCB = [&](InsertPointTy allocaIP, InsertPointTy codeGenIP) {
588     // MaskedOp has only one region associated with it.
589     auto &region = maskedOp.getRegion();
590     builder.restoreIP(codeGenIP);
591     return convertOmpOpRegions(region, "omp.masked.region", builder,
592                                moduleTranslation)
593         .takeError();
594   };
595 
596   // TODO: Perform finalization actions for variables. This has to be
597   // called for variables which have destructors/finalizers.
598   auto finiCB = [&](InsertPointTy codeGenIP) { return llvm::Error::success(); };
599 
600   llvm::Value *filterVal = nullptr;
601   if (auto filterVar = maskedOp.getFilteredThreadId()) {
602     filterVal = moduleTranslation.lookupValue(filterVar);
603   } else {
604     llvm::LLVMContext &llvmContext = builder.getContext();
605     filterVal =
606         llvm::ConstantInt::get(llvm::Type::getInt32Ty(llvmContext), /*V=*/0);
607   }
608   assert(filterVal != nullptr);
609   llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
610   llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
611       moduleTranslation.getOpenMPBuilder()->createMasked(ompLoc, bodyGenCB,
612                                                          finiCB, filterVal);
613 
614   if (failed(handleError(afterIP, opInst)))
615     return failure();
616 
617   builder.restoreIP(*afterIP);
618   return success();
619 }
620 
621 /// Converts an OpenMP 'master' operation into LLVM IR using OpenMPIRBuilder.
622 static LogicalResult
623 convertOmpMaster(Operation &opInst, llvm::IRBuilderBase &builder,
624                  LLVM::ModuleTranslation &moduleTranslation) {
625   using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
626   auto masterOp = cast<omp::MasterOp>(opInst);
627 
628   if (failed(checkImplementationStatus(opInst)))
629     return failure();
630 
631   auto bodyGenCB = [&](InsertPointTy allocaIP, InsertPointTy codeGenIP) {
632     // MasterOp has only one region associated with it.
633     auto &region = masterOp.getRegion();
634     builder.restoreIP(codeGenIP);
635     return convertOmpOpRegions(region, "omp.master.region", builder,
636                                moduleTranslation)
637         .takeError();
638   };
639 
640   // TODO: Perform finalization actions for variables. This has to be
641   // called for variables which have destructors/finalizers.
642   auto finiCB = [&](InsertPointTy codeGenIP) { return llvm::Error::success(); };
643 
644   llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
645   llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
646       moduleTranslation.getOpenMPBuilder()->createMaster(ompLoc, bodyGenCB,
647                                                          finiCB);
648 
649   if (failed(handleError(afterIP, opInst)))
650     return failure();
651 
652   builder.restoreIP(*afterIP);
653   return success();
654 }
655 
656 /// Converts an OpenMP 'critical' operation into LLVM IR using OpenMPIRBuilder.
657 static LogicalResult
658 convertOmpCritical(Operation &opInst, llvm::IRBuilderBase &builder,
659                    LLVM::ModuleTranslation &moduleTranslation) {
660   using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
661   auto criticalOp = cast<omp::CriticalOp>(opInst);
662 
663   if (failed(checkImplementationStatus(opInst)))
664     return failure();
665 
666   auto bodyGenCB = [&](InsertPointTy allocaIP, InsertPointTy codeGenIP) {
667     // CriticalOp has only one region associated with it.
668     auto &region = cast<omp::CriticalOp>(opInst).getRegion();
669     builder.restoreIP(codeGenIP);
670     return convertOmpOpRegions(region, "omp.critical.region", builder,
671                                moduleTranslation)
672         .takeError();
673   };
674 
675   // TODO: Perform finalization actions for variables. This has to be
676   // called for variables which have destructors/finalizers.
677   auto finiCB = [&](InsertPointTy codeGenIP) { return llvm::Error::success(); };
678 
679   llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
680   llvm::LLVMContext &llvmContext = moduleTranslation.getLLVMContext();
681   llvm::Constant *hint = nullptr;
682 
683   // If it has a name, it probably has a hint too.
684   if (criticalOp.getNameAttr()) {
685     // The verifiers in OpenMP Dialect guarentee that all the pointers are
686     // non-null
687     auto symbolRef = cast<SymbolRefAttr>(criticalOp.getNameAttr());
688     auto criticalDeclareOp =
689         SymbolTable::lookupNearestSymbolFrom<omp::CriticalDeclareOp>(criticalOp,
690                                                                      symbolRef);
691     hint =
692         llvm::ConstantInt::get(llvm::Type::getInt32Ty(llvmContext),
693                                static_cast<int>(criticalDeclareOp.getHint()));
694   }
695   llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
696       moduleTranslation.getOpenMPBuilder()->createCritical(
697           ompLoc, bodyGenCB, finiCB, criticalOp.getName().value_or(""), hint);
698 
699   if (failed(handleError(afterIP, opInst)))
700     return failure();
701 
702   builder.restoreIP(*afterIP);
703   return success();
704 }
705 
706 /// Populates `privatizations` with privatization declarations used for the
707 /// given op.
708 template <class OP>
709 static void collectPrivatizationDecls(
710     OP op, SmallVectorImpl<omp::PrivateClauseOp> &privatizations) {
711   std::optional<ArrayAttr> attr = op.getPrivateSyms();
712   if (!attr)
713     return;
714 
715   privatizations.reserve(privatizations.size() + attr->size());
716   for (auto symbolRef : attr->getAsRange<SymbolRefAttr>()) {
717     privatizations.push_back(findPrivatizer(op, symbolRef));
718   }
719 }
720 
721 /// Populates `reductions` with reduction declarations used in the given op.
722 template <typename T>
723 static void
724 collectReductionDecls(T op,
725                       SmallVectorImpl<omp::DeclareReductionOp> &reductions) {
726   std::optional<ArrayAttr> attr = op.getReductionSyms();
727   if (!attr)
728     return;
729 
730   reductions.reserve(reductions.size() + op.getNumReductionVars());
731   for (auto symbolRef : attr->getAsRange<SymbolRefAttr>()) {
732     reductions.push_back(
733         SymbolTable::lookupNearestSymbolFrom<omp::DeclareReductionOp>(
734             op, symbolRef));
735   }
736 }
737 
/// Translates the blocks contained in the given region and appends them to at
/// the current insertion point of `builder`. The operations of the entry block
/// are appended to the current insertion block. If set, `continuationBlockArgs`
/// is populated with translated values that correspond to the values
/// omp.yield'ed from the region.
static LogicalResult inlineConvertOmpRegions(
    Region &region, StringRef blockName, llvm::IRBuilderBase &builder,
    LLVM::ModuleTranslation &moduleTranslation,
    SmallVectorImpl<llvm::Value *> *continuationBlockArgs = nullptr) {
  // Nothing to translate for an empty region.
  if (region.empty())
    return success();

  // Special case for single-block regions that don't create additional blocks:
  // insert operations without creating additional blocks.
  if (llvm::hasSingleElement(region)) {
    // If the insertion block already ends in a terminator, temporarily detach
    // it so the region's operations can be appended; it is re-attached below.
    llvm::Instruction *potentialTerminator =
        builder.GetInsertBlock()->empty() ? nullptr
                                          : &builder.GetInsertBlock()->back();

    if (potentialTerminator && potentialTerminator->isTerminator())
      potentialTerminator->removeFromParent();
    moduleTranslation.mapBlock(&region.front(), builder.GetInsertBlock());

    if (failed(moduleTranslation.convertBlock(
            region.front(), /*ignoreArguments=*/true, builder)))
      return failure();

    // The continuation arguments are simply the translated terminator operands.
    if (continuationBlockArgs)
      llvm::append_range(
          *continuationBlockArgs,
          moduleTranslation.lookupValues(region.front().back().getOperands()));

    // Drop the mapping that is no longer necessary so that the same region can
    // be processed multiple times.
    moduleTranslation.forgetMapping(region);

    // Re-attach the terminator that was detached above at the end of the
    // (possibly grown) insertion block.
    if (potentialTerminator && potentialTerminator->isTerminator()) {
      llvm::BasicBlock *block = builder.GetInsertBlock();
      if (block->empty()) {
        // this can happen for really simple reduction init regions e.g.
        // %0 = llvm.mlir.constant(0 : i32) : i32
        // omp.yield(%0 : i32)
        // because the llvm.mlir.constant (MLIR op) isn't converted into any
        // llvm op
        potentialTerminator->insertInto(block, block->begin());
      } else {
        potentialTerminator->insertAfter(&block->back());
      }
    }

    return success();
  }

  // General case: translate the region into a fresh set of blocks. `phis`
  // receives the continuation block's PHI nodes, which carry the yielded
  // values.
  SmallVector<llvm::PHINode *> phis;
  llvm::Expected<llvm::BasicBlock *> continuationBlock =
      convertOmpOpRegions(region, blockName, builder, moduleTranslation, &phis);

  if (failed(handleError(continuationBlock, *region.getParentOp())))
    return failure();

  if (continuationBlockArgs)
    llvm::append_range(*continuationBlockArgs, phis);
  // Leave the builder positioned at the start of the continuation block so the
  // caller can keep emitting code right after the inlined region.
  builder.SetInsertPoint(*continuationBlock,
                         (*continuationBlock)->getFirstInsertionPt());
  return success();
}
805 
namespace {
/// Owning equivalents of OpenMPIRBuilder::(Atomic)ReductionGen that are used to
/// store lambdas with capture.
/// Non-atomic combiner: (insertion point, lhs, rhs, result out-param) ->
/// updated insertion point or error.
using OwningReductionGen =
    std::function<llvm::OpenMPIRBuilder::InsertPointOrErrorTy(
        llvm::OpenMPIRBuilder::InsertPointTy, llvm::Value *, llvm::Value *,
        llvm::Value *&)>;
/// Atomic combiner: (insertion point, element type, lhs, rhs) -> updated
/// insertion point or error.
using OwningAtomicReductionGen =
    std::function<llvm::OpenMPIRBuilder::InsertPointOrErrorTy(
        llvm::OpenMPIRBuilder::InsertPointTy, llvm::Type *, llvm::Value *,
        llvm::Value *)>;
} // namespace
818 
819 /// Create an OpenMPIRBuilder-compatible reduction generator for the given
820 /// reduction declaration. The generator uses `builder` but ignores its
821 /// insertion point.
822 static OwningReductionGen
823 makeReductionGen(omp::DeclareReductionOp decl, llvm::IRBuilderBase &builder,
824                  LLVM::ModuleTranslation &moduleTranslation) {
825   // The lambda is mutable because we need access to non-const methods of decl
826   // (which aren't actually mutating it), and we must capture decl by-value to
827   // avoid the dangling reference after the parent function returns.
828   OwningReductionGen gen =
829       [&, decl](llvm::OpenMPIRBuilder::InsertPointTy insertPoint,
830                 llvm::Value *lhs, llvm::Value *rhs,
831                 llvm::Value *&result) mutable
832       -> llvm::OpenMPIRBuilder::InsertPointOrErrorTy {
833     moduleTranslation.mapValue(decl.getReductionLhsArg(), lhs);
834     moduleTranslation.mapValue(decl.getReductionRhsArg(), rhs);
835     builder.restoreIP(insertPoint);
836     SmallVector<llvm::Value *> phis;
837     if (failed(inlineConvertOmpRegions(decl.getReductionRegion(),
838                                        "omp.reduction.nonatomic.body", builder,
839                                        moduleTranslation, &phis)))
840       return llvm::createStringError(
841           "failed to inline `combiner` region of `omp.declare_reduction`");
842     assert(phis.size() == 1);
843     result = phis[0];
844     return builder.saveIP();
845   };
846   return gen;
847 }
848 
849 /// Create an OpenMPIRBuilder-compatible atomic reduction generator for the
850 /// given reduction declaration. The generator uses `builder` but ignores its
851 /// insertion point. Returns null if there is no atomic region available in the
852 /// reduction declaration.
853 static OwningAtomicReductionGen
854 makeAtomicReductionGen(omp::DeclareReductionOp decl,
855                        llvm::IRBuilderBase &builder,
856                        LLVM::ModuleTranslation &moduleTranslation) {
857   if (decl.getAtomicReductionRegion().empty())
858     return OwningAtomicReductionGen();
859 
860   // The lambda is mutable because we need access to non-const methods of decl
861   // (which aren't actually mutating it), and we must capture decl by-value to
862   // avoid the dangling reference after the parent function returns.
863   OwningAtomicReductionGen atomicGen =
864       [&, decl](llvm::OpenMPIRBuilder::InsertPointTy insertPoint, llvm::Type *,
865                 llvm::Value *lhs, llvm::Value *rhs) mutable
866       -> llvm::OpenMPIRBuilder::InsertPointOrErrorTy {
867     moduleTranslation.mapValue(decl.getAtomicReductionLhsArg(), lhs);
868     moduleTranslation.mapValue(decl.getAtomicReductionRhsArg(), rhs);
869     builder.restoreIP(insertPoint);
870     SmallVector<llvm::Value *> phis;
871     if (failed(inlineConvertOmpRegions(decl.getAtomicReductionRegion(),
872                                        "omp.reduction.atomic.body", builder,
873                                        moduleTranslation, &phis)))
874       return llvm::createStringError(
875           "failed to inline `atomic` region of `omp.declare_reduction`");
876     assert(phis.empty());
877     return builder.saveIP();
878   };
879   return atomicGen;
880 }
881 
882 /// Converts an OpenMP 'ordered' operation into LLVM IR using OpenMPIRBuilder.
883 static LogicalResult
884 convertOmpOrdered(Operation &opInst, llvm::IRBuilderBase &builder,
885                   LLVM::ModuleTranslation &moduleTranslation) {
886   auto orderedOp = cast<omp::OrderedOp>(opInst);
887 
888   if (failed(checkImplementationStatus(opInst)))
889     return failure();
890 
891   omp::ClauseDepend dependType = *orderedOp.getDoacrossDependType();
892   bool isDependSource = dependType == omp::ClauseDepend::dependsource;
893   unsigned numLoops = *orderedOp.getDoacrossNumLoops();
894   SmallVector<llvm::Value *> vecValues =
895       moduleTranslation.lookupValues(orderedOp.getDoacrossDependVars());
896 
897   size_t indexVecValues = 0;
898   while (indexVecValues < vecValues.size()) {
899     SmallVector<llvm::Value *> storeValues;
900     storeValues.reserve(numLoops);
901     for (unsigned i = 0; i < numLoops; i++) {
902       storeValues.push_back(vecValues[indexVecValues]);
903       indexVecValues++;
904     }
905     llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
906         findAllocaInsertPoint(builder, moduleTranslation);
907     llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
908     builder.restoreIP(moduleTranslation.getOpenMPBuilder()->createOrderedDepend(
909         ompLoc, allocaIP, numLoops, storeValues, ".cnt.addr", isDependSource));
910   }
911   return success();
912 }
913 
914 /// Converts an OpenMP 'ordered_region' operation into LLVM IR using
915 /// OpenMPIRBuilder.
916 static LogicalResult
917 convertOmpOrderedRegion(Operation &opInst, llvm::IRBuilderBase &builder,
918                         LLVM::ModuleTranslation &moduleTranslation) {
919   using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
920   auto orderedRegionOp = cast<omp::OrderedRegionOp>(opInst);
921 
922   if (failed(checkImplementationStatus(opInst)))
923     return failure();
924 
925   auto bodyGenCB = [&](InsertPointTy allocaIP, InsertPointTy codeGenIP) {
926     // OrderedOp has only one region associated with it.
927     auto &region = cast<omp::OrderedRegionOp>(opInst).getRegion();
928     builder.restoreIP(codeGenIP);
929     return convertOmpOpRegions(region, "omp.ordered.region", builder,
930                                moduleTranslation)
931         .takeError();
932   };
933 
934   // TODO: Perform finalization actions for variables. This has to be
935   // called for variables which have destructors/finalizers.
936   auto finiCB = [&](InsertPointTy codeGenIP) { return llvm::Error::success(); };
937 
938   llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
939   llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
940       moduleTranslation.getOpenMPBuilder()->createOrderedThreadsSimd(
941           ompLoc, bodyGenCB, finiCB, !orderedRegionOp.getParLevelSimd());
942 
943   if (failed(handleError(afterIP, opInst)))
944     return failure();
945 
946   builder.restoreIP(*afterIP);
947   return success();
948 }
949 
namespace {
/// Contains the arguments for an LLVM store operation
/// whose emission is deferred until after all allocas (see
/// allocReductionVars).
struct DeferredStore {
  DeferredStore(llvm::Value *value, llvm::Value *address)
      : value(value), address(address) {}

  llvm::Value *value;   // value to store
  llvm::Value *address; // pointer to store `value` into
};
} // namespace
960 
961 /// Allocate space for privatized reduction variables.
962 /// `deferredStores` contains information to create store operations which needs
963 /// to be inserted after all allocas
964 template <typename T>
965 static LogicalResult
966 allocReductionVars(T loop, ArrayRef<BlockArgument> reductionArgs,
967                    llvm::IRBuilderBase &builder,
968                    LLVM::ModuleTranslation &moduleTranslation,
969                    const llvm::OpenMPIRBuilder::InsertPointTy &allocaIP,
970                    SmallVectorImpl<omp::DeclareReductionOp> &reductionDecls,
971                    SmallVectorImpl<llvm::Value *> &privateReductionVariables,
972                    DenseMap<Value, llvm::Value *> &reductionVariableMap,
973                    SmallVectorImpl<DeferredStore> &deferredStores,
974                    llvm::ArrayRef<bool> isByRefs) {
975   llvm::IRBuilderBase::InsertPointGuard guard(builder);
976   builder.SetInsertPoint(allocaIP.getBlock()->getTerminator());
977 
978   // delay creating stores until after all allocas
979   deferredStores.reserve(loop.getNumReductionVars());
980 
981   for (std::size_t i = 0; i < loop.getNumReductionVars(); ++i) {
982     Region &allocRegion = reductionDecls[i].getAllocRegion();
983     if (isByRefs[i]) {
984       if (allocRegion.empty())
985         continue;
986 
987       SmallVector<llvm::Value *, 1> phis;
988       if (failed(inlineConvertOmpRegions(allocRegion, "omp.reduction.alloc",
989                                          builder, moduleTranslation, &phis)))
990         return loop.emitError(
991             "failed to inline `alloc` region of `omp.declare_reduction`");
992 
993       assert(phis.size() == 1 && "expected one allocation to be yielded");
994       builder.SetInsertPoint(allocaIP.getBlock()->getTerminator());
995 
996       // Allocate reduction variable (which is a pointer to the real reduction
997       // variable allocated in the inlined region)
998       llvm::Value *var = builder.CreateAlloca(
999           moduleTranslation.convertType(reductionDecls[i].getType()));
1000       deferredStores.emplace_back(phis[0], var);
1001 
1002       privateReductionVariables[i] = var;
1003       moduleTranslation.mapValue(reductionArgs[i], phis[0]);
1004       reductionVariableMap.try_emplace(loop.getReductionVars()[i], phis[0]);
1005     } else {
1006       assert(allocRegion.empty() &&
1007              "allocaction is implicit for by-val reduction");
1008       llvm::Value *var = builder.CreateAlloca(
1009           moduleTranslation.convertType(reductionDecls[i].getType()));
1010       moduleTranslation.mapValue(reductionArgs[i], var);
1011       privateReductionVariables[i] = var;
1012       reductionVariableMap.try_emplace(loop.getReductionVars()[i], var);
1013     }
1014   }
1015 
1016   return success();
1017 }
1018 
1019 /// Map input arguments to reduction initialization region
1020 template <typename T>
1021 static void
1022 mapInitializationArgs(T loop, LLVM::ModuleTranslation &moduleTranslation,
1023                       SmallVectorImpl<omp::DeclareReductionOp> &reductionDecls,
1024                       DenseMap<Value, llvm::Value *> &reductionVariableMap,
1025                       unsigned i) {
1026   // map input argument to the initialization region
1027   mlir::omp::DeclareReductionOp &reduction = reductionDecls[i];
1028   Region &initializerRegion = reduction.getInitializerRegion();
1029   Block &entry = initializerRegion.front();
1030 
1031   mlir::Value mlirSource = loop.getReductionVars()[i];
1032   llvm::Value *llvmSource = moduleTranslation.lookupValue(mlirSource);
1033   assert(llvmSource && "lookup reduction var");
1034   moduleTranslation.mapValue(reduction.getInitializerMoldArg(), llvmSource);
1035 
1036   if (entry.getNumArguments() > 1) {
1037     llvm::Value *allocation =
1038         reductionVariableMap.lookup(loop.getReductionVars()[i]);
1039     moduleTranslation.mapValue(reduction.getInitializerAllocArg(), allocation);
1040   }
1041 }
1042 
/// Inline reductions' `init` regions. This function assumes that the
/// `builder`'s insertion point is where the user wants the `init` regions to be
/// inlined; i.e. it does not try to find a proper insertion location for the
/// `init` regions. It also leaves the `builder`'s insertion point in a state
/// where the user can continue the code-gen directly afterwards.
template <typename OP>
static LogicalResult
initReductionVars(OP op, ArrayRef<BlockArgument> reductionArgs,
                  llvm::IRBuilderBase &builder,
                  LLVM::ModuleTranslation &moduleTranslation,
                  llvm::BasicBlock *latestAllocaBlock,
                  SmallVectorImpl<omp::DeclareReductionOp> &reductionDecls,
                  SmallVectorImpl<llvm::Value *> &privateReductionVariables,
                  DenseMap<Value, llvm::Value *> &reductionVariableMap,
                  llvm::ArrayRef<bool> isByRef,
                  SmallVectorImpl<DeferredStore> &deferredStores) {
  if (op.getNumReductionVars() == 0)
    return success();

  // Split off a dedicated block for the initializer code, and position the
  // builder in the alloca block for the allocas emitted first.
  llvm::BasicBlock *initBlock = splitBB(builder, true, "omp.reduction.init");
  auto allocaIP = llvm::IRBuilderBase::InsertPoint(
      latestAllocaBlock, latestAllocaBlock->getTerminator()->getIterator());
  builder.restoreIP(allocaIP);
  SmallVector<llvm::Value *> byRefVars(op.getNumReductionVars());

  for (unsigned i = 0; i < op.getNumReductionVars(); ++i) {
    if (isByRef[i]) {
      // By-ref reductions with an alloc region were already allocated in
      // allocReductionVars.
      if (!reductionDecls[i].getAllocRegion().empty())
        continue;

      // TODO: remove after all users of by-ref are updated to use the alloc
      // region: Allocate reduction variable (which is a pointer to the real
      // reduction variable allocated in the inlined region)
      byRefVars[i] = builder.CreateAlloca(
          moduleTranslation.convertType(reductionDecls[i].getType()));
    }
  }

  // Continue emitting in the init block, before its terminator if it has one.
  if (initBlock->empty() || initBlock->getTerminator() == nullptr)
    builder.SetInsertPoint(initBlock);
  else
    builder.SetInsertPoint(initBlock->getTerminator());

  // store result of the alloc region to the allocated pointer to the real
  // reduction variable
  for (auto [data, addr] : deferredStores)
    builder.CreateStore(data, addr);

  // Before the loop, store the initial values of reductions into reduction
  // variables. Although this could be done after allocas, we don't want to mess
  // up with the alloca insertion point.
  for (unsigned i = 0; i < op.getNumReductionVars(); ++i) {
    SmallVector<llvm::Value *, 1> phis;

    // map block argument to initializer region
    mapInitializationArgs(op, moduleTranslation, reductionDecls,
                          reductionVariableMap, i);

    if (failed(inlineConvertOmpRegions(reductionDecls[i].getInitializerRegion(),
                                       "omp.reduction.neutral", builder,
                                       moduleTranslation, &phis)))
      return failure();

    assert(phis.size() == 1 && "expected one value to be yielded from the "
                               "reduction neutral element declaration region");

    // Inlining may have moved us to a new block; keep emitting before that
    // block's terminator, if it has one.
    if (builder.GetInsertBlock()->empty() ||
        builder.GetInsertBlock()->getTerminator() == nullptr)
      builder.SetInsertPoint(builder.GetInsertBlock())
    else
      builder.SetInsertPoint(builder.GetInsertBlock()->getTerminator());

    if (isByRef[i]) {
      if (!reductionDecls[i].getAllocRegion().empty())
        // done in allocReductionVars
        continue;

      // TODO: this path can be removed once all users of by-ref are updated to
      // use an alloc region

      // Store the result of the inlined region to the allocated reduction var
      // ptr
      builder.CreateStore(phis[0], byRefVars[i]);

      privateReductionVariables[i] = byRefVars[i];
      moduleTranslation.mapValue(reductionArgs[i], phis[0]);
      reductionVariableMap.try_emplace(op.getReductionVars()[i], phis[0]);
    } else {
      // for by-ref case the store is inside of the reduction region
      builder.CreateStore(phis[0], privateReductionVariables[i]);
      // the rest was handled in allocByValReductionVars
    }

    // forget the mapping for the initializer region because we might need a
    // different mapping if this reduction declaration is re-used for a
    // different variable
    moduleTranslation.forgetMapping(reductionDecls[i].getInitializerRegion());
  }

  return success();
}
1144 
/// Collect reduction info: build a (possibly atomic) reduction generator for
/// each reduction declaration of `loop` and assemble the corresponding
/// OpenMPIRBuilder::ReductionInfo entries. The `owningReductionGens` and
/// `owningAtomicReductionGens` vectors own the generator closures and must
/// outlive any use of `reductionInfos` (ReductionInfo only accepts references
/// to the generators).
template <typename T>
static void collectReductionInfo(
    T loop, llvm::IRBuilderBase &builder,
    LLVM::ModuleTranslation &moduleTranslation,
    SmallVectorImpl<omp::DeclareReductionOp> &reductionDecls,
    SmallVectorImpl<OwningReductionGen> &owningReductionGens,
    SmallVectorImpl<OwningAtomicReductionGen> &owningAtomicReductionGens,
    const ArrayRef<llvm::Value *> privateReductionVariables,
    SmallVectorImpl<llvm::OpenMPIRBuilder::ReductionInfo> &reductionInfos) {
  unsigned numReductions = loop.getNumReductionVars();

  // First pass: fully populate the owning vectors so that the references taken
  // in the second pass are not invalidated by further growth.
  for (unsigned i = 0; i < numReductions; ++i) {
    owningReductionGens.push_back(
        makeReductionGen(reductionDecls[i], builder, moduleTranslation));
    owningAtomicReductionGens.push_back(
        makeAtomicReductionGen(reductionDecls[i], builder, moduleTranslation));
  }

  // Collect the reduction information.
  reductionInfos.reserve(numReductions);
  for (unsigned i = 0; i < numReductions; ++i) {
    // A null atomic generator means the declaration has no atomic region.
    llvm::OpenMPIRBuilder::ReductionGenAtomicCBTy atomicGen = nullptr;
    if (owningAtomicReductionGens[i])
      atomicGen = owningAtomicReductionGens[i];
    llvm::Value *variable =
        moduleTranslation.lookupValue(loop.getReductionVars()[i]);
    reductionInfos.push_back(
        {moduleTranslation.convertType(reductionDecls[i].getType()), variable,
         privateReductionVariables[i],
         /*EvaluationKind=*/llvm::OpenMPIRBuilder::EvalKind::Scalar,
         owningReductionGens[i],
         /*ReductionGenClang=*/nullptr, atomicGen});
  }
}
1180 
/// handling of DeclareReductionOp's cleanup region: inlines each non-empty
/// cleanup region at the current insertion point, binding its single block
/// argument to the matching entry of `privateVariables` (loaded first when
/// `shouldLoadCleanupRegionArg` is set).
static LogicalResult
inlineOmpRegionCleanup(llvm::SmallVectorImpl<Region *> &cleanupRegions,
                       llvm::ArrayRef<llvm::Value *> privateVariables,
                       LLVM::ModuleTranslation &moduleTranslation,
                       llvm::IRBuilderBase &builder, StringRef regionName,
                       bool shouldLoadCleanupRegionArg = true) {
  for (auto [i, cleanupRegion] : llvm::enumerate(cleanupRegions)) {
    if (cleanupRegion->empty())
      continue;

    // map the argument to the cleanup region
    Block &entry = cleanupRegion->front();

    // If the insertion block already ends with a terminator, emit the cleanup
    // code before the terminator instead of after it.
    llvm::Instruction *potentialTerminator =
        builder.GetInsertBlock()->empty() ? nullptr
                                          : &builder.GetInsertBlock()->back();
    if (potentialTerminator && potentialTerminator->isTerminator())
      builder.SetInsertPoint(potentialTerminator);
    // Pass either the loaded value or the raw pointer to the region, depending
    // on `shouldLoadCleanupRegionArg`.
    llvm::Value *privateVarValue =
        shouldLoadCleanupRegionArg
            ? builder.CreateLoad(
                  moduleTranslation.convertType(entry.getArgument(0).getType()),
                  privateVariables[i])
            : privateVariables[i];

    moduleTranslation.mapValue(entry.getArgument(0), privateVarValue);

    if (failed(inlineConvertOmpRegions(*cleanupRegion, regionName, builder,
                                       moduleTranslation)))
      return failure();

    // clear block argument mapping in case it needs to be re-created with a
    // different source for another use of the same reduction decl
    moduleTranslation.forgetMapping(*cleanupRegion);
  }
  return success();
}
1219 
1220 // TODO: not used by ParallelOp
1221 template <class OP>
1222 static LogicalResult createReductionsAndCleanup(
1223     OP op, llvm::IRBuilderBase &builder,
1224     LLVM::ModuleTranslation &moduleTranslation,
1225     llvm::OpenMPIRBuilder::InsertPointTy &allocaIP,
1226     SmallVectorImpl<omp::DeclareReductionOp> &reductionDecls,
1227     ArrayRef<llvm::Value *> privateReductionVariables, ArrayRef<bool> isByRef) {
1228   // Process the reductions if required.
1229   if (op.getNumReductionVars() == 0)
1230     return success();
1231 
1232   llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
1233 
1234   // Create the reduction generators. We need to own them here because
1235   // ReductionInfo only accepts references to the generators.
1236   SmallVector<OwningReductionGen> owningReductionGens;
1237   SmallVector<OwningAtomicReductionGen> owningAtomicReductionGens;
1238   SmallVector<llvm::OpenMPIRBuilder::ReductionInfo> reductionInfos;
1239   collectReductionInfo(op, builder, moduleTranslation, reductionDecls,
1240                        owningReductionGens, owningAtomicReductionGens,
1241                        privateReductionVariables, reductionInfos);
1242 
1243   // The call to createReductions below expects the block to have a
1244   // terminator. Create an unreachable instruction to serve as terminator
1245   // and remove it later.
1246   llvm::UnreachableInst *tempTerminator = builder.CreateUnreachable();
1247   builder.SetInsertPoint(tempTerminator);
1248   llvm::OpenMPIRBuilder::InsertPointOrErrorTy contInsertPoint =
1249       ompBuilder->createReductions(builder.saveIP(), allocaIP, reductionInfos,
1250                                    isByRef, op.getNowait());
1251 
1252   if (failed(handleError(contInsertPoint, *op)))
1253     return failure();
1254 
1255   if (!contInsertPoint->getBlock())
1256     return op->emitOpError() << "failed to convert reductions";
1257 
1258   llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
1259       ompBuilder->createBarrier(*contInsertPoint, llvm::omp::OMPD_for);
1260 
1261   if (failed(handleError(afterIP, *op)))
1262     return failure();
1263 
1264   tempTerminator->eraseFromParent();
1265   builder.restoreIP(*afterIP);
1266 
1267   // after the construct, deallocate private reduction variables
1268   SmallVector<Region *> reductionRegions;
1269   llvm::transform(reductionDecls, std::back_inserter(reductionRegions),
1270                   [](omp::DeclareReductionOp reductionDecl) {
1271                     return &reductionDecl.getCleanupRegion();
1272                   });
1273   return inlineOmpRegionCleanup(reductionRegions, privateReductionVariables,
1274                                 moduleTranslation, builder,
1275                                 "omp.reduction.cleanup");
1276   return success();
1277 }
1278 
1279 static ArrayRef<bool> getIsByRef(std::optional<ArrayRef<bool>> attr) {
1280   if (!attr)
1281     return {};
1282   return *attr;
1283 }
1284 
1285 // TODO: not used by omp.parallel
1286 template <typename OP>
1287 static LogicalResult allocAndInitializeReductionVars(
1288     OP op, ArrayRef<BlockArgument> reductionArgs, llvm::IRBuilderBase &builder,
1289     LLVM::ModuleTranslation &moduleTranslation,
1290     llvm::OpenMPIRBuilder::InsertPointTy &allocaIP,
1291     SmallVectorImpl<omp::DeclareReductionOp> &reductionDecls,
1292     SmallVectorImpl<llvm::Value *> &privateReductionVariables,
1293     DenseMap<Value, llvm::Value *> &reductionVariableMap,
1294     llvm::ArrayRef<bool> isByRef) {
1295   if (op.getNumReductionVars() == 0)
1296     return success();
1297 
1298   SmallVector<DeferredStore> deferredStores;
1299 
1300   if (failed(allocReductionVars(op, reductionArgs, builder, moduleTranslation,
1301                                 allocaIP, reductionDecls,
1302                                 privateReductionVariables, reductionVariableMap,
1303                                 deferredStores, isByRef)))
1304     return failure();
1305 
1306   return initReductionVars(op, reductionArgs, builder, moduleTranslation,
1307                            allocaIP.getBlock(), reductionDecls,
1308                            privateReductionVariables, reductionVariableMap,
1309                            isByRef, deferredStores);
1310 }
1311 
1312 /// Return the llvm::Value * corresponding to the `privateVar` that
1313 /// is being privatized. It isn't always as simple as looking up
1314 /// moduleTranslation with privateVar. For instance, in case of
1315 /// an allocatable, the descriptor for the allocatable is privatized.
1316 /// This descriptor is mapped using an MapInfoOp. So, this function
1317 /// will return a pointer to the llvm::Value corresponding to the
1318 /// block argument for the mapped descriptor.
1319 static llvm::Value *
1320 findAssociatedValue(Value privateVar, llvm::IRBuilderBase &builder,
1321                     LLVM::ModuleTranslation &moduleTranslation,
1322                     llvm::DenseMap<Value, Value> *mappedPrivateVars = nullptr) {
1323   if (mappedPrivateVars == nullptr || !mappedPrivateVars->contains(privateVar))
1324     return moduleTranslation.lookupValue(privateVar);
1325 
1326   Value blockArg = (*mappedPrivateVars)[privateVar];
1327   Type privVarType = privateVar.getType();
1328   Type blockArgType = blockArg.getType();
1329   assert(isa<LLVM::LLVMPointerType>(blockArgType) &&
1330          "A block argument corresponding to a mapped var should have "
1331          "!llvm.ptr type");
1332 
1333   if (privVarType == blockArgType)
1334     return moduleTranslation.lookupValue(blockArg);
1335 
1336   // This typically happens when the privatized type is lowered from
1337   // boxchar<KIND> and gets lowered to !llvm.struct<(ptr, i64)>. That is the
1338   // struct/pair is passed by value. But, mapped values are passed only as
1339   // pointers, so before we privatize, we must load the pointer.
1340   if (!isa<LLVM::LLVMPointerType>(privVarType))
1341     return builder.CreateLoad(moduleTranslation.convertType(privVarType),
1342                               moduleTranslation.lookupValue(blockArg));
1343 
1344   return moduleTranslation.lookupValue(privateVar);
1345 }
1346 
1347 /// Allocate delayed private variables. Returns the basic block which comes
1348 /// after all of these allocations. llvm::Value * for each of these private
1349 /// variables are populated in llvmPrivateVars.
1350 static llvm::Expected<llvm::BasicBlock *>
1351 allocatePrivateVars(llvm::IRBuilderBase &builder,
1352                     LLVM::ModuleTranslation &moduleTranslation,
1353                     MutableArrayRef<BlockArgument> privateBlockArgs,
1354                     MutableArrayRef<omp::PrivateClauseOp> privateDecls,
1355                     MutableArrayRef<mlir::Value> mlirPrivateVars,
1356                     llvm::SmallVectorImpl<llvm::Value *> &llvmPrivateVars,
1357                     const llvm::OpenMPIRBuilder::InsertPointTy &allocaIP,
1358                     llvm::DenseMap<Value, Value> *mappedPrivateVars = nullptr) {
1359   // Allocate private vars
1360   llvm::BranchInst *allocaTerminator =
1361       llvm::cast<llvm::BranchInst>(allocaIP.getBlock()->getTerminator());
1362   splitBB(llvm::OpenMPIRBuilder::InsertPointTy(allocaIP.getBlock(),
1363                                                allocaTerminator->getIterator()),
1364           true, "omp.region.after_alloca");
1365 
1366   llvm::IRBuilderBase::InsertPointGuard guard(builder);
1367   // Update the allocaTerminator in case the alloca block was split above.
1368   allocaTerminator =
1369       llvm::cast<llvm::BranchInst>(allocaIP.getBlock()->getTerminator());
1370   builder.SetInsertPoint(allocaTerminator);
1371   assert(allocaTerminator->getNumSuccessors() == 1 &&
1372          "This is an unconditional branch created by OpenMPIRBuilder");
1373 
1374   llvm::BasicBlock *afterAllocas = allocaTerminator->getSuccessor(0);
1375 
1376   // FIXME: Some of the allocation regions do more than just allocating.
1377   // They read from their block argument (amongst other non-alloca things).
1378   // When OpenMPIRBuilder outlines the parallel region into a different
1379   // function it places the loads for live in-values (such as these block
1380   // arguments) at the end of the entry block (because the entry block is
1381   // assumed to contain only allocas). Therefore, if we put these complicated
1382   // alloc blocks in the entry block, these will not dominate the availability
1383   // of the live-in values they are using. Fix this by adding a latealloc
1384   // block after the entry block to put these in (this also helps to avoid
1385   // mixing non-alloca code with allocas).
1386   // Alloc regions which do not use the block argument can still be placed in
1387   // the entry block (therefore keeping the allocas together).
1388   llvm::BasicBlock *privAllocBlock = nullptr;
1389   if (!privateBlockArgs.empty())
1390     privAllocBlock = splitBB(builder, true, "omp.private.latealloc");
1391   for (auto [privDecl, mlirPrivVar, blockArg] :
1392        llvm::zip_equal(privateDecls, mlirPrivateVars, privateBlockArgs)) {
1393     Region &allocRegion = privDecl.getAllocRegion();
1394 
1395     // map allocation region block argument
1396     llvm::Value *nonPrivateVar = findAssociatedValue(
1397         mlirPrivVar, builder, moduleTranslation, mappedPrivateVars);
1398     assert(nonPrivateVar);
1399     moduleTranslation.mapValue(privDecl.getAllocMoldArg(), nonPrivateVar);
1400 
1401     // in-place convert the private allocation region
1402     SmallVector<llvm::Value *, 1> phis;
1403     if (privDecl.getAllocMoldArg().getUses().empty()) {
1404       // TODO this should use
1405       // allocaIP.getBlock()->getFirstNonPHIOrDbgOrAlloca() so it goes before
1406       // the code for fetching the thread id. Not doing this for now to avoid
1407       // test churn.
1408       builder.SetInsertPoint(allocaIP.getBlock()->getTerminator());
1409     } else {
1410       builder.SetInsertPoint(privAllocBlock->getTerminator());
1411     }
1412 
1413     if (failed(inlineConvertOmpRegions(allocRegion, "omp.private.alloc",
1414                                        builder, moduleTranslation, &phis)))
1415       return llvm::createStringError(
1416           "failed to inline `alloc` region of `omp.private`");
1417 
1418     assert(phis.size() == 1 && "expected one allocation to be yielded");
1419 
1420     moduleTranslation.mapValue(blockArg, phis[0]);
1421     llvmPrivateVars.push_back(phis[0]);
1422 
1423     // clear alloc region block argument mapping in case it needs to be
1424     // re-created with a different source for another use of the same
1425     // reduction decl
1426     moduleTranslation.forgetMapping(allocRegion);
1427   }
1428   return afterAllocas;
1429 }
1430 
1431 static LogicalResult
1432 initFirstPrivateVars(llvm::IRBuilderBase &builder,
1433                      LLVM::ModuleTranslation &moduleTranslation,
1434                      SmallVectorImpl<mlir::Value> &mlirPrivateVars,
1435                      SmallVectorImpl<llvm::Value *> &llvmPrivateVars,
1436                      SmallVectorImpl<omp::PrivateClauseOp> &privateDecls,
1437                      llvm::BasicBlock *afterAllocas) {
1438   llvm::IRBuilderBase::InsertPointGuard guard(builder);
1439   // Apply copy region for firstprivate.
1440   bool needsFirstprivate =
1441       llvm::any_of(privateDecls, [](omp::PrivateClauseOp &privOp) {
1442         return privOp.getDataSharingType() ==
1443                omp::DataSharingClauseType::FirstPrivate;
1444       });
1445 
1446   if (!needsFirstprivate)
1447     return success();
1448 
1449   assert(afterAllocas->getSinglePredecessor());
1450 
1451   // Find the end of the allocation blocks
1452   builder.SetInsertPoint(afterAllocas->getSinglePredecessor()->getTerminator());
1453   llvm::BasicBlock *copyBlock =
1454       splitBB(builder, /*CreateBranch=*/true, "omp.private.copy");
1455   builder.SetInsertPoint(copyBlock->getFirstNonPHIOrDbgOrAlloca());
1456 
1457   for (auto [decl, mlirVar, llvmVar] :
1458        llvm::zip_equal(privateDecls, mlirPrivateVars, llvmPrivateVars)) {
1459     if (decl.getDataSharingType() != omp::DataSharingClauseType::FirstPrivate)
1460       continue;
1461 
1462     // copyRegion implements `lhs = rhs`
1463     Region &copyRegion = decl.getCopyRegion();
1464 
1465     // map copyRegion rhs arg
1466     llvm::Value *nonPrivateVar = moduleTranslation.lookupValue(mlirVar);
1467     assert(nonPrivateVar);
1468     moduleTranslation.mapValue(decl.getCopyMoldArg(), nonPrivateVar);
1469 
1470     // map copyRegion lhs arg
1471     moduleTranslation.mapValue(decl.getCopyPrivateArg(), llvmVar);
1472 
1473     // in-place convert copy region
1474     builder.SetInsertPoint(builder.GetInsertBlock()->getTerminator());
1475     if (failed(inlineConvertOmpRegions(copyRegion, "omp.private.copy", builder,
1476                                        moduleTranslation)))
1477       return decl.emitError("failed to inline `copy` region of `omp.private`");
1478 
1479     // ignore unused value yielded from copy region
1480 
1481     // clear copy region block argument mapping in case it needs to be
1482     // re-created with different sources for reuse of the same reduction
1483     // decl
1484     moduleTranslation.forgetMapping(copyRegion);
1485   }
1486 
1487   return success();
1488 }
1489 
1490 static LogicalResult
1491 cleanupPrivateVars(llvm::IRBuilderBase &builder,
1492                    LLVM::ModuleTranslation &moduleTranslation, Location loc,
1493                    SmallVectorImpl<llvm::Value *> &llvmPrivateVars,
1494                    SmallVectorImpl<omp::PrivateClauseOp> &privateDecls) {
1495   // private variable deallocation
1496   SmallVector<Region *> privateCleanupRegions;
1497   llvm::transform(privateDecls, std::back_inserter(privateCleanupRegions),
1498                   [](omp::PrivateClauseOp privatizer) {
1499                     return &privatizer.getDeallocRegion();
1500                   });
1501 
1502   if (failed(inlineOmpRegionCleanup(
1503           privateCleanupRegions, llvmPrivateVars, moduleTranslation, builder,
1504           "omp.private.dealloc", /*shouldLoadCleanupRegionArg=*/false)))
1505     return mlir::emitError(loc, "failed to inline `dealloc` region of an "
1506                                 "`omp.private` op in");
1507 
1508   return success();
1509 }
1510 
/// Converts an `omp.sections` construct (and the `omp.section` ops nested in
/// its region) into LLVM IR via OpenMPIRBuilder::createSections, including
/// allocation/initialization of reduction variables and post-loop reduction
/// cleanup.
static LogicalResult
convertOmpSections(Operation &opInst, llvm::IRBuilderBase &builder,
                   LLVM::ModuleTranslation &moduleTranslation) {
  using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
  using StorableBodyGenCallbackTy =
      llvm::OpenMPIRBuilder::StorableBodyGenCallbackTy;

  auto sectionsOp = cast<omp::SectionsOp>(opInst);

  // Bail out early on clauses this translation does not support yet.
  if (failed(checkImplementationStatus(opInst)))
    return failure();

  llvm::ArrayRef<bool> isByRef = getIsByRef(sectionsOp.getReductionByref());
  assert(isByRef.size() == sectionsOp.getNumReductionVars());

  // Allocate and initialize a private copy for each reduction variable.
  SmallVector<omp::DeclareReductionOp> reductionDecls;
  collectReductionDecls(sectionsOp, reductionDecls);
  llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
      findAllocaInsertPoint(builder, moduleTranslation);

  SmallVector<llvm::Value *> privateReductionVariables(
      sectionsOp.getNumReductionVars());
  DenseMap<Value, llvm::Value *> reductionVariableMap;

  MutableArrayRef<BlockArgument> reductionArgs =
      cast<omp::BlockArgOpenMPOpInterface>(opInst).getReductionBlockArgs();

  if (failed(allocAndInitializeReductionVars(
          sectionsOp, reductionArgs, builder, moduleTranslation, allocaIP,
          reductionDecls, privateReductionVariables, reductionVariableMap,
          isByRef)))
    return failure();

  // Store the mapping between reduction variables and their private copies on
  // ModuleTranslation stack. It can be then recovered when translating
  // omp.reduce operations in a separate call.
  LLVM::ModuleTranslation::SaveStack<OpenMPVarMappingStackFrame> mappingGuard(
      moduleTranslation, reductionVariableMap);

  // Build one body-generation callback per nested omp.section; the IR builder
  // invokes each callback to emit that section's body.
  SmallVector<StorableBodyGenCallbackTy> sectionCBs;

  for (Operation &op : *sectionsOp.getRegion().begin()) {
    auto sectionOp = dyn_cast<omp::SectionOp>(op);
    if (!sectionOp) // omp.terminator
      continue;

    Region &region = sectionOp.getRegion();
    // NOTE: captures by reference; `region` and `sectionsOp` must outlive the
    // createSections call below (they do — both are rooted in `opInst`).
    auto sectionCB = [&sectionsOp, &region, &builder, &moduleTranslation](
                         InsertPointTy allocaIP, InsertPointTy codeGenIP) {
      builder.restoreIP(codeGenIP);

      // map the omp.section reduction block argument to the omp.sections block
      // arguments
      // TODO: this assumes that the only block arguments are reduction
      // variables
      assert(region.getNumArguments() ==
             sectionsOp.getRegion().getNumArguments());
      for (auto [sectionsArg, sectionArg] : llvm::zip_equal(
               sectionsOp.getRegion().getArguments(), region.getArguments())) {
        llvm::Value *llvmVal = moduleTranslation.lookupValue(sectionsArg);
        assert(llvmVal);
        moduleTranslation.mapValue(sectionArg, llvmVal);
      }

      return convertOmpOpRegions(region, "omp.section.region", builder,
                                 moduleTranslation)
          .takeError();
    };
    sectionCBs.push_back(sectionCB);
  }

  // No sections within omp.sections operation - skip generation. This situation
  // is only possible if there is only a terminator operation inside the
  // sections operation
  if (sectionCBs.empty())
    return success();

  assert(isa<omp::SectionOp>(*sectionsOp.getRegion().op_begin()));

  // TODO: Perform appropriate actions according to the data-sharing
  // attribute (shared, private, firstprivate, ...) of variables.
  // Currently defaults to shared.
  auto privCB = [&](InsertPointTy, InsertPointTy codeGenIP, llvm::Value &,
                    llvm::Value &vPtr, llvm::Value *&replacementValue)
      -> llvm::OpenMPIRBuilder::InsertPointOrErrorTy {
    replacementValue = &vPtr;
    return codeGenIP;
  };

  // TODO: Perform finalization actions for variables. This has to be
  // called for variables which have destructors/finalizers.
  auto finiCB = [&](InsertPointTy codeGenIP) { return llvm::Error::success(); };

  allocaIP = findAllocaInsertPoint(builder, moduleTranslation);
  llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
  llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
      moduleTranslation.getOpenMPBuilder()->createSections(
          ompLoc, allocaIP, sectionCBs, privCB, finiCB, false,
          sectionsOp.getNowait());

  if (failed(handleError(afterIP, opInst)))
    return failure();

  // Continue emitting IR after the whole sections construct.
  builder.restoreIP(*afterIP);

  // Process the reductions if required.
  return createReductionsAndCleanup(sectionsOp, builder, moduleTranslation,
                                    allocaIP, reductionDecls,
                                    privateReductionVariables, isByRef);
}
1621 
1622 /// Converts an OpenMP single construct into LLVM IR using OpenMPIRBuilder.
1623 static LogicalResult
1624 convertOmpSingle(omp::SingleOp &singleOp, llvm::IRBuilderBase &builder,
1625                  LLVM::ModuleTranslation &moduleTranslation) {
1626   using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
1627   llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
1628 
1629   if (failed(checkImplementationStatus(*singleOp)))
1630     return failure();
1631 
1632   auto bodyCB = [&](InsertPointTy allocaIP, InsertPointTy codegenIP) {
1633     builder.restoreIP(codegenIP);
1634     return convertOmpOpRegions(singleOp.getRegion(), "omp.single.region",
1635                                builder, moduleTranslation)
1636         .takeError();
1637   };
1638   auto finiCB = [&](InsertPointTy codeGenIP) { return llvm::Error::success(); };
1639 
1640   // Handle copyprivate
1641   Operation::operand_range cpVars = singleOp.getCopyprivateVars();
1642   std::optional<ArrayAttr> cpFuncs = singleOp.getCopyprivateSyms();
1643   llvm::SmallVector<llvm::Value *> llvmCPVars;
1644   llvm::SmallVector<llvm::Function *> llvmCPFuncs;
1645   for (size_t i = 0, e = cpVars.size(); i < e; ++i) {
1646     llvmCPVars.push_back(moduleTranslation.lookupValue(cpVars[i]));
1647     auto llvmFuncOp = SymbolTable::lookupNearestSymbolFrom<LLVM::LLVMFuncOp>(
1648         singleOp, cast<SymbolRefAttr>((*cpFuncs)[i]));
1649     llvmCPFuncs.push_back(
1650         moduleTranslation.lookupFunction(llvmFuncOp.getName()));
1651   }
1652 
1653   llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
1654       moduleTranslation.getOpenMPBuilder()->createSingle(
1655           ompLoc, bodyCB, finiCB, singleOp.getNowait(), llvmCPVars,
1656           llvmCPFuncs);
1657 
1658   if (failed(handleError(afterIP, *singleOp)))
1659     return failure();
1660 
1661   builder.restoreIP(*afterIP);
1662   return success();
1663 }
1664 
1665 // Convert an OpenMP Teams construct to LLVM IR using OpenMPIRBuilder
1666 static LogicalResult
1667 convertOmpTeams(omp::TeamsOp op, llvm::IRBuilderBase &builder,
1668                 LLVM::ModuleTranslation &moduleTranslation) {
1669   using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
1670   if (failed(checkImplementationStatus(*op)))
1671     return failure();
1672 
1673   auto bodyCB = [&](InsertPointTy allocaIP, InsertPointTy codegenIP) {
1674     LLVM::ModuleTranslation::SaveStack<OpenMPAllocaStackFrame> frame(
1675         moduleTranslation, allocaIP);
1676     builder.restoreIP(codegenIP);
1677     return convertOmpOpRegions(op.getRegion(), "omp.teams.region", builder,
1678                                moduleTranslation)
1679         .takeError();
1680   };
1681 
1682   llvm::Value *numTeamsLower = nullptr;
1683   if (Value numTeamsLowerVar = op.getNumTeamsLower())
1684     numTeamsLower = moduleTranslation.lookupValue(numTeamsLowerVar);
1685 
1686   llvm::Value *numTeamsUpper = nullptr;
1687   if (Value numTeamsUpperVar = op.getNumTeamsUpper())
1688     numTeamsUpper = moduleTranslation.lookupValue(numTeamsUpperVar);
1689 
1690   llvm::Value *threadLimit = nullptr;
1691   if (Value threadLimitVar = op.getThreadLimit())
1692     threadLimit = moduleTranslation.lookupValue(threadLimitVar);
1693 
1694   llvm::Value *ifExpr = nullptr;
1695   if (Value ifVar = op.getIfExpr())
1696     ifExpr = moduleTranslation.lookupValue(ifVar);
1697 
1698   llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
1699   llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
1700       moduleTranslation.getOpenMPBuilder()->createTeams(
1701           ompLoc, bodyCB, numTeamsLower, numTeamsUpper, threadLimit, ifExpr);
1702 
1703   if (failed(handleError(afterIP, *op)))
1704     return failure();
1705 
1706   builder.restoreIP(*afterIP);
1707   return success();
1708 }
1709 
1710 static void
1711 buildDependData(std::optional<ArrayAttr> dependKinds, OperandRange dependVars,
1712                 LLVM::ModuleTranslation &moduleTranslation,
1713                 SmallVectorImpl<llvm::OpenMPIRBuilder::DependData> &dds) {
1714   if (dependVars.empty())
1715     return;
1716   for (auto dep : llvm::zip(dependVars, dependKinds->getValue())) {
1717     llvm::omp::RTLDependenceKindTy type;
1718     switch (
1719         cast<mlir::omp::ClauseTaskDependAttr>(std::get<1>(dep)).getValue()) {
1720     case mlir::omp::ClauseTaskDepend::taskdependin:
1721       type = llvm::omp::RTLDependenceKindTy::DepIn;
1722       break;
1723     // The OpenMP runtime requires that the codegen for 'depend' clause for
1724     // 'out' dependency kind must be the same as codegen for 'depend' clause
1725     // with 'inout' dependency.
1726     case mlir::omp::ClauseTaskDepend::taskdependout:
1727     case mlir::omp::ClauseTaskDepend::taskdependinout:
1728       type = llvm::omp::RTLDependenceKindTy::DepInOut;
1729       break;
1730     case mlir::omp::ClauseTaskDepend::taskdependmutexinoutset:
1731       type = llvm::omp::RTLDependenceKindTy::DepMutexInOutSet;
1732       break;
1733     case mlir::omp::ClauseTaskDepend::taskdependinoutset:
1734       type = llvm::omp::RTLDependenceKindTy::DepInOutSet;
1735       break;
1736     };
1737     llvm::Value *depVal = moduleTranslation.lookupValue(std::get<0>(dep));
1738     llvm::OpenMPIRBuilder::DependData dd(type, depVal->getType(), depVal);
1739     dds.emplace_back(dd);
1740   }
1741 }
1742 
/// Converts an OpenMP task construct into LLVM IR using OpenMPIRBuilder.
/// Handles delayed privatization (alloc + firstprivate copy + dealloc regions
/// of the referenced `omp.private` ops) and `depend` clause translation.
static LogicalResult
convertOmpTaskOp(omp::TaskOp taskOp, llvm::IRBuilderBase &builder,
                 LLVM::ModuleTranslation &moduleTranslation) {
  using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
  // Bail out early on clauses this translation does not support yet.
  if (failed(checkImplementationStatus(*taskOp)))
    return failure();

  // Collect delayed privatisation declarations
  MutableArrayRef<BlockArgument> privateBlockArgs =
      cast<omp::BlockArgOpenMPOpInterface>(*taskOp).getPrivateBlockArgs();
  SmallVector<mlir::Value> mlirPrivateVars;
  SmallVector<llvm::Value *> llvmPrivateVars;
  SmallVector<omp::PrivateClauseOp> privateDecls;
  mlirPrivateVars.reserve(privateBlockArgs.size());
  llvmPrivateVars.reserve(privateBlockArgs.size());
  collectPrivatizationDecls(taskOp, privateDecls);
  for (mlir::Value privateVar : taskOp.getPrivateVars())
    mlirPrivateVars.push_back(privateVar);

  // Body callback invoked by the IR builder; emits privatization, the task
  // region itself, and private-variable cleanup, in that order.
  auto bodyCB = [&](InsertPointTy allocaIP,
                    InsertPointTy codegenIP) -> llvm::Error {
    // Save the alloca insertion point on ModuleTranslation stack for use in
    // nested regions.
    LLVM::ModuleTranslation::SaveStack<OpenMPAllocaStackFrame> frame(
        moduleTranslation, allocaIP);

    // Allocate the private copies, then run the firstprivate copy regions.
    llvm::Expected<llvm::BasicBlock *> afterAllocas = allocatePrivateVars(
        builder, moduleTranslation, privateBlockArgs, privateDecls,
        mlirPrivateVars, llvmPrivateVars, allocaIP);
    if (handleError(afterAllocas, *taskOp).failed())
      return llvm::make_error<PreviouslyReportedError>();

    if (failed(initFirstPrivateVars(builder, moduleTranslation, mlirPrivateVars,
                                    llvmPrivateVars, privateDecls,
                                    afterAllocas.get())))
      return llvm::make_error<PreviouslyReportedError>();

    // translate the body of the task:
    builder.restoreIP(codegenIP);
    auto continuationBlockOrError = convertOmpOpRegions(
        taskOp.getRegion(), "omp.task.region", builder, moduleTranslation);
    if (failed(handleError(continuationBlockOrError, *taskOp)))
      return llvm::make_error<PreviouslyReportedError>();

    // Run the `dealloc` regions before control leaves the task body.
    builder.SetInsertPoint(continuationBlockOrError.get()->getTerminator());

    if (failed(cleanupPrivateVars(builder, moduleTranslation, taskOp.getLoc(),
                                  llvmPrivateVars, privateDecls)))
      return llvm::make_error<PreviouslyReportedError>();

    return llvm::Error::success();
  };

  // Translate the `depend` clause, if any.
  SmallVector<llvm::OpenMPIRBuilder::DependData> dds;
  buildDependData(taskOp.getDependKinds(), taskOp.getDependVars(),
                  moduleTranslation, dds);

  llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
      findAllocaInsertPoint(builder, moduleTranslation);
  llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
  llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
      moduleTranslation.getOpenMPBuilder()->createTask(
          ompLoc, allocaIP, bodyCB, !taskOp.getUntied(),
          moduleTranslation.lookupValue(taskOp.getFinal()),
          moduleTranslation.lookupValue(taskOp.getIfExpr()), dds,
          taskOp.getMergeable(),
          moduleTranslation.lookupValue(taskOp.getEventHandle()),
          moduleTranslation.lookupValue(taskOp.getPriority()));

  if (failed(handleError(afterIP, *taskOp)))
    return failure();

  // Resume emission after the task construct.
  builder.restoreIP(*afterIP);
  return success();
}
1819 
1820 /// Converts an OpenMP taskgroup construct into LLVM IR using OpenMPIRBuilder.
1821 static LogicalResult
1822 convertOmpTaskgroupOp(omp::TaskgroupOp tgOp, llvm::IRBuilderBase &builder,
1823                       LLVM::ModuleTranslation &moduleTranslation) {
1824   using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
1825   if (failed(checkImplementationStatus(*tgOp)))
1826     return failure();
1827 
1828   auto bodyCB = [&](InsertPointTy allocaIP, InsertPointTy codegenIP) {
1829     builder.restoreIP(codegenIP);
1830     return convertOmpOpRegions(tgOp.getRegion(), "omp.taskgroup.region",
1831                                builder, moduleTranslation)
1832         .takeError();
1833   };
1834 
1835   InsertPointTy allocaIP = findAllocaInsertPoint(builder, moduleTranslation);
1836   llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
1837   llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
1838       moduleTranslation.getOpenMPBuilder()->createTaskgroup(ompLoc, allocaIP,
1839                                                             bodyCB);
1840 
1841   if (failed(handleError(afterIP, *tgOp)))
1842     return failure();
1843 
1844   builder.restoreIP(*afterIP);
1845   return success();
1846 }
1847 
1848 static LogicalResult
1849 convertOmpTaskwaitOp(omp::TaskwaitOp twOp, llvm::IRBuilderBase &builder,
1850                      LLVM::ModuleTranslation &moduleTranslation) {
1851   if (failed(checkImplementationStatus(*twOp)))
1852     return failure();
1853 
1854   moduleTranslation.getOpenMPBuilder()->createTaskwait(builder.saveIP());
1855   return success();
1856 }
1857 
/// Converts an OpenMP workshare loop into LLVM IR using OpenMPIRBuilder.
/// Sequencing is delicate: private allocation, reduction allocation,
/// firstprivate copies, reduction init, canonical-loop creation (innermost
/// body last), loop collapsing, workshare-loop application, then reduction
/// and privatization cleanup.
static LogicalResult
convertOmpWsloop(Operation &opInst, llvm::IRBuilderBase &builder,
                 LLVM::ModuleTranslation &moduleTranslation) {
  auto wsloopOp = cast<omp::WsloopOp>(opInst);
  // Bail out early on clauses this translation does not support yet.
  if (failed(checkImplementationStatus(opInst)))
    return failure();

  auto loopOp = cast<omp::LoopNestOp>(wsloopOp.getWrappedLoop());
  llvm::ArrayRef<bool> isByRef = getIsByRef(wsloopOp.getReductionByref());
  assert(isByRef.size() == wsloopOp.getNumReductionVars());

  // Static is the default.
  auto schedule =
      wsloopOp.getScheduleKind().value_or(omp::ClauseScheduleKind::Static);

  // Find the loop configuration.
  llvm::Value *step = moduleTranslation.lookupValue(loopOp.getLoopSteps()[0]);
  llvm::Type *ivType = step->getType();
  llvm::Value *chunk = nullptr;
  if (wsloopOp.getScheduleChunk()) {
    llvm::Value *chunkVar =
        moduleTranslation.lookupValue(wsloopOp.getScheduleChunk());
    // Coerce the chunk size to the induction variable's type.
    chunk = builder.CreateSExtOrTrunc(chunkVar, ivType);
  }

  // Collect delayed privatization declarations and their MLIR operands.
  MutableArrayRef<BlockArgument> privateBlockArgs =
      cast<omp::BlockArgOpenMPOpInterface>(*wsloopOp).getPrivateBlockArgs();
  SmallVector<mlir::Value> mlirPrivateVars;
  SmallVector<llvm::Value *> llvmPrivateVars;
  SmallVector<omp::PrivateClauseOp> privateDecls;
  mlirPrivateVars.reserve(privateBlockArgs.size());
  llvmPrivateVars.reserve(privateBlockArgs.size());
  collectPrivatizationDecls(wsloopOp, privateDecls);

  for (mlir::Value privateVar : wsloopOp.getPrivateVars())
    mlirPrivateVars.push_back(privateVar);

  SmallVector<omp::DeclareReductionOp> reductionDecls;
  collectReductionDecls(wsloopOp, reductionDecls);
  llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
      findAllocaInsertPoint(builder, moduleTranslation);

  SmallVector<llvm::Value *> privateReductionVariables(
      wsloopOp.getNumReductionVars());

  // Allocate private copies; `afterAllocas` is the block following all of
  // the allocation blocks.
  llvm::Expected<llvm::BasicBlock *> afterAllocas = allocatePrivateVars(
      builder, moduleTranslation, privateBlockArgs, privateDecls,
      mlirPrivateVars, llvmPrivateVars, allocaIP);
  if (handleError(afterAllocas, opInst).failed())
    return failure();

  DenseMap<Value, llvm::Value *> reductionVariableMap;

  MutableArrayRef<BlockArgument> reductionArgs =
      cast<omp::BlockArgOpenMPOpInterface>(opInst).getReductionBlockArgs();

  // Stores whose emission must be deferred until after reduction init.
  SmallVector<DeferredStore> deferredStores;

  if (failed(allocReductionVars(wsloopOp, reductionArgs, builder,
                                moduleTranslation, allocaIP, reductionDecls,
                                privateReductionVariables, reductionVariableMap,
                                deferredStores, isByRef)))
    return failure();

  // Run the firstprivate copy regions after allocation, before the loop body.
  if (failed(initFirstPrivateVars(builder, moduleTranslation, mlirPrivateVars,
                                  llvmPrivateVars, privateDecls,
                                  afterAllocas.get())))
    return failure();

  assert(afterAllocas.get()->getSinglePredecessor());
  if (failed(initReductionVars(wsloopOp, reductionArgs, builder,
                               moduleTranslation,
                               afterAllocas.get()->getSinglePredecessor(),
                               reductionDecls, privateReductionVariables,
                               reductionVariableMap, isByRef, deferredStores)))
    return failure();

  // TODO: Replace this with proper composite translation support.
  // Currently, all nested wrappers are ignored, so 'do/for simd' will be
  // treated the same as a standalone 'do/for'. This is allowed by the spec,
  // since it's equivalent to always using a SIMD length of 1.
  if (failed(convertIgnoredWrappers(loopOp, wsloopOp, moduleTranslation)))
    return failure();

  // Store the mapping between reduction variables and their private copies on
  // ModuleTranslation stack. It can be then recovered when translating
  // omp.reduce operations in a separate call.
  LLVM::ModuleTranslation::SaveStack<OpenMPVarMappingStackFrame> mappingGuard(
      moduleTranslation, reductionVariableMap);

  // Set up the source location value for OpenMP runtime.
  llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);

  // Generator of the canonical loop body.
  SmallVector<llvm::CanonicalLoopInfo *> loopInfos;
  SmallVector<llvm::OpenMPIRBuilder::InsertPointTy> bodyInsertPoints;
  // Called once per loop in the nest; the actual body is only emitted when
  // the innermost loop is reached (loopInfos.size() == numLoops - 1).
  auto bodyGen = [&](llvm::OpenMPIRBuilder::InsertPointTy ip,
                     llvm::Value *iv) -> llvm::Error {
    // Make sure further conversions know about the induction variable.
    moduleTranslation.mapValue(
        loopOp.getRegion().front().getArgument(loopInfos.size()), iv);

    // Capture the body insertion point for use in nested loops. BodyIP of the
    // CanonicalLoopInfo always points to the beginning of the entry block of
    // the body.
    bodyInsertPoints.push_back(ip);

    if (loopInfos.size() != loopOp.getNumLoops() - 1)
      return llvm::Error::success();

    // Convert the body of the loop.
    builder.restoreIP(ip);
    return convertOmpOpRegions(loopOp.getRegion(), "omp.wsloop.region", builder,
                               moduleTranslation)
        .takeError();
  };

  // Delegate actual loop construction to the OpenMP IRBuilder.
  // TODO: this currently assumes omp.loop_nest is semantically similar to SCF
  // loop, i.e. it has a positive step, uses signed integer semantics.
  // Reconsider this code when the nested loop operation clearly supports more
  // cases.
  llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
  for (unsigned i = 0, e = loopOp.getNumLoops(); i < e; ++i) {
    llvm::Value *lowerBound =
        moduleTranslation.lookupValue(loopOp.getLoopLowerBounds()[i]);
    llvm::Value *upperBound =
        moduleTranslation.lookupValue(loopOp.getLoopUpperBounds()[i]);
    llvm::Value *step = moduleTranslation.lookupValue(loopOp.getLoopSteps()[i]);

    // Make sure loop trip count are emitted in the preheader of the outermost
    // loop at the latest so that they are all available for the new collapsed
    // loop will be created below.
    llvm::OpenMPIRBuilder::LocationDescription loc = ompLoc;
    llvm::OpenMPIRBuilder::InsertPointTy computeIP = ompLoc.IP;
    if (i != 0) {
      loc = llvm::OpenMPIRBuilder::LocationDescription(bodyInsertPoints.back());
      computeIP = loopInfos.front()->getPreheaderIP();
    }

    llvm::Expected<llvm::CanonicalLoopInfo *> loopResult =
        ompBuilder->createCanonicalLoop(
            loc, bodyGen, lowerBound, upperBound, step,
            /*IsSigned=*/true, loopOp.getLoopInclusive(), computeIP);

    if (failed(handleError(loopResult, *loopOp)))
      return failure();

    loopInfos.push_back(*loopResult);
  }

  // Collapse loops. Store the insertion point because LoopInfos may get
  // invalidated.
  llvm::IRBuilderBase::InsertPoint afterIP = loopInfos.front()->getAfterIP();
  llvm::CanonicalLoopInfo *loopInfo =
      ompBuilder->collapseLoops(ompLoc.DL, loopInfos, {});

  allocaIP = findAllocaInsertPoint(builder, moduleTranslation);

  // TODO: Handle doacross loops when the ordered clause has a parameter.
  bool isOrdered = wsloopOp.getOrdered().has_value();
  std::optional<omp::ScheduleModifier> scheduleMod = wsloopOp.getScheduleMod();
  bool isSimd = wsloopOp.getScheduleSimd();

  llvm::OpenMPIRBuilder::InsertPointOrErrorTy wsloopIP =
      ompBuilder->applyWorkshareLoop(
          ompLoc.DL, loopInfo, allocaIP, !wsloopOp.getNowait(),
          convertToScheduleKind(schedule), chunk, isSimd,
          scheduleMod == omp::ScheduleModifier::monotonic,
          scheduleMod == omp::ScheduleModifier::nonmonotonic, isOrdered);

  if (failed(handleError(wsloopIP, opInst)))
    return failure();

  // Continue building IR after the loop. Note that the LoopInfo returned by
  // `collapseLoops` points inside the outermost loop and is intended for
  // potential further loop transformations. Use the insertion point stored
  // before collapsing loops instead.
  builder.restoreIP(afterIP);

  // Process the reductions if required.
  if (failed(createReductionsAndCleanup(wsloopOp, builder, moduleTranslation,
                                        allocaIP, reductionDecls,
                                        privateReductionVariables, isByRef)))
    return failure();

  // Finally, run the `dealloc` regions for the private copies.
  return cleanupPrivateVars(builder, moduleTranslation, wsloopOp.getLoc(),
                            llvmPrivateVars, privateDecls);
}
2048 
/// Converts the OpenMP parallel operation to LLVM IR.
///
/// Outlining of the parallel region, privatization bookkeeping and reduction
/// finalization are delegated to OpenMPIRBuilder::createParallel via three
/// callbacks: bodyGenCB (region body + reductions), privCB (no-op, since
/// privatization is done in bodyGenCB) and finiCB (cleanup regions).
static LogicalResult
convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder,
                   LLVM::ModuleTranslation &moduleTranslation) {
  using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
  // One by-ref/by-value flag per reduction variable (asserted below).
  ArrayRef<bool> isByRef = getIsByRef(opInst.getReductionByref());
  assert(isByRef.size() == opInst.getNumReductionVars());
  llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();

  // Bail out early on clauses this translation does not handle yet.
  if (failed(checkImplementationStatus(*opInst)))
    return failure();

  // Collect delayed privatization declarations
  MutableArrayRef<BlockArgument> privateBlockArgs =
      cast<omp::BlockArgOpenMPOpInterface>(*opInst).getPrivateBlockArgs();
  SmallVector<mlir::Value> mlirPrivateVars;
  SmallVector<llvm::Value *> llvmPrivateVars;
  SmallVector<omp::PrivateClauseOp> privateDecls;
  mlirPrivateVars.reserve(privateBlockArgs.size());
  llvmPrivateVars.reserve(privateBlockArgs.size());
  collectPrivatizationDecls(opInst, privateDecls);
  for (mlir::Value privateVar : opInst.getPrivateVars())
    mlirPrivateVars.push_back(privateVar);

  // Collect reduction declarations
  SmallVector<omp::DeclareReductionOp> reductionDecls;
  collectReductionDecls(opInst, reductionDecls);
  SmallVector<llvm::Value *> privateReductionVariables(
      opInst.getNumReductionVars());
  // Stores produced during reduction allocation that must be emitted later;
  // filled by allocReductionVars, consumed by initReductionVars.
  SmallVector<DeferredStore> deferredStores;

  // Emits the outlined region body: private/reduction variable allocation and
  // initialization, the translated region itself, and the final reduction
  // combination.
  auto bodyGenCB = [&](InsertPointTy allocaIP,
                       InsertPointTy codeGenIP) -> llvm::Error {
    llvm::Expected<llvm::BasicBlock *> afterAllocas = allocatePrivateVars(
        builder, moduleTranslation, privateBlockArgs, privateDecls,
        mlirPrivateVars, llvmPrivateVars, allocaIP);
    if (handleError(afterAllocas, *opInst).failed())
      return llvm::make_error<PreviouslyReportedError>();

    // Allocate reduction vars
    DenseMap<Value, llvm::Value *> reductionVariableMap;

    MutableArrayRef<BlockArgument> reductionArgs =
        cast<omp::BlockArgOpenMPOpInterface>(*opInst).getReductionBlockArgs();

    // Re-anchor the alloca insertion point at the block terminator so the
    // reduction allocas are placed after the private-variable allocas above.
    allocaIP =
        InsertPointTy(allocaIP.getBlock(),
                      allocaIP.getBlock()->getTerminator()->getIterator());

    if (failed(allocReductionVars(
            opInst, reductionArgs, builder, moduleTranslation, allocaIP,
            reductionDecls, privateReductionVariables, reductionVariableMap,
            deferredStores, isByRef)))
      return llvm::make_error<PreviouslyReportedError>();

    if (failed(initFirstPrivateVars(builder, moduleTranslation, mlirPrivateVars,
                                    llvmPrivateVars, privateDecls,
                                    afterAllocas.get())))
      return llvm::make_error<PreviouslyReportedError>();

    assert(afterAllocas.get()->getSinglePredecessor());
    builder.restoreIP(codeGenIP);

    if (failed(
            initReductionVars(opInst, reductionArgs, builder, moduleTranslation,
                              afterAllocas.get()->getSinglePredecessor(),
                              reductionDecls, privateReductionVariables,
                              reductionVariableMap, isByRef, deferredStores)))
      return llvm::make_error<PreviouslyReportedError>();

    // Store the mapping between reduction variables and their private copies on
    // ModuleTranslation stack. It can be then recovered when translating
    // omp.reduce operations in a separate call.
    LLVM::ModuleTranslation::SaveStack<OpenMPVarMappingStackFrame> mappingGuard(
        moduleTranslation, reductionVariableMap);

    // Save the alloca insertion point on ModuleTranslation stack for use in
    // nested regions.
    LLVM::ModuleTranslation::SaveStack<OpenMPAllocaStackFrame> frame(
        moduleTranslation, allocaIP);

    // ParallelOp has only one region associated with it.
    llvm::Expected<llvm::BasicBlock *> regionBlock = convertOmpOpRegions(
        opInst.getRegion(), "omp.par.region", builder, moduleTranslation);
    if (!regionBlock)
      return regionBlock.takeError();

    // Process the reductions if required.
    if (opInst.getNumReductionVars() > 0) {
      // Collect reduction info
      SmallVector<OwningReductionGen> owningReductionGens;
      SmallVector<OwningAtomicReductionGen> owningAtomicReductionGens;
      SmallVector<llvm::OpenMPIRBuilder::ReductionInfo> reductionInfos;
      collectReductionInfo(opInst, builder, moduleTranslation, reductionDecls,
                           owningReductionGens, owningAtomicReductionGens,
                           privateReductionVariables, reductionInfos);

      // Move to region cont block
      builder.SetInsertPoint((*regionBlock)->getTerminator());

      // Generate reductions from info
      // A temporary `unreachable` anchors the insertion point while
      // createReductions emits code; it is erased again once done.
      llvm::UnreachableInst *tempTerminator = builder.CreateUnreachable();
      builder.SetInsertPoint(tempTerminator);

      llvm::OpenMPIRBuilder::InsertPointOrErrorTy contInsertPoint =
          ompBuilder->createReductions(builder.saveIP(), allocaIP,
                                       reductionInfos, isByRef, false);
      if (!contInsertPoint)
        return contInsertPoint.takeError();

      if (!contInsertPoint->getBlock())
        return llvm::make_error<PreviouslyReportedError>();

      tempTerminator->eraseFromParent();
      builder.restoreIP(*contInsertPoint);
    }
    return llvm::Error::success();
  };

  auto privCB = [](InsertPointTy allocaIP, InsertPointTy codeGenIP,
                   llvm::Value &, llvm::Value &val, llvm::Value *&replVal) {
    // tell OpenMPIRBuilder not to do anything. We handled Privatisation in
    // bodyGenCB.
    replVal = &val;
    return codeGenIP;
  };

  // TODO: Perform finalization actions for variables. This has to be
  // called for variables which have destructors/finalizers.
  auto finiCB = [&](InsertPointTy codeGenIP) -> llvm::Error {
    InsertPointTy oldIP = builder.saveIP();
    builder.restoreIP(codeGenIP);

    // if the reduction has a cleanup region, inline it here to finalize the
    // reduction variables
    SmallVector<Region *> reductionCleanupRegions;
    llvm::transform(reductionDecls, std::back_inserter(reductionCleanupRegions),
                    [](omp::DeclareReductionOp reductionDecl) {
                      return &reductionDecl.getCleanupRegion();
                    });
    if (failed(inlineOmpRegionCleanup(
            reductionCleanupRegions, privateReductionVariables,
            moduleTranslation, builder, "omp.reduction.cleanup")))
      return llvm::createStringError(
          "failed to inline `cleanup` region of `omp.declare_reduction`");

    if (failed(cleanupPrivateVars(builder, moduleTranslation, opInst.getLoc(),
                                  llvmPrivateVars, privateDecls)))
      return llvm::make_error<PreviouslyReportedError>();

    builder.restoreIP(oldIP);
    return llvm::Error::success();
  };

  // Translate the if / num_threads / proc_bind clauses when present.
  llvm::Value *ifCond = nullptr;
  if (auto ifVar = opInst.getIfExpr())
    ifCond = moduleTranslation.lookupValue(ifVar);
  llvm::Value *numThreads = nullptr;
  if (auto numThreadsVar = opInst.getNumThreads())
    numThreads = moduleTranslation.lookupValue(numThreadsVar);
  auto pbKind = llvm::omp::OMP_PROC_BIND_default;
  if (auto bind = opInst.getProcBindKind())
    pbKind = getProcBindKind(*bind);
  // TODO: Is the Parallel construct cancellable?
  bool isCancellable = false;

  llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
      findAllocaInsertPoint(builder, moduleTranslation);
  llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);

  llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
      ompBuilder->createParallel(ompLoc, allocaIP, bodyGenCB, privCB, finiCB,
                                 ifCond, numThreads, pbKind, isCancellable);

  if (failed(handleError(afterIP, *opInst)))
    return failure();

  builder.restoreIP(*afterIP);
  return success();
}
2229 
2230 /// Convert Order attribute to llvm::omp::OrderKind.
2231 static llvm::omp::OrderKind
2232 convertOrderKind(std::optional<omp::ClauseOrderKind> o) {
2233   if (!o)
2234     return llvm::omp::OrderKind::OMP_ORDER_unknown;
2235   switch (*o) {
2236   case omp::ClauseOrderKind::Concurrent:
2237     return llvm::omp::OrderKind::OMP_ORDER_concurrent;
2238   }
2239   llvm_unreachable("Unknown ClauseOrderKind kind");
2240 }
2241 
/// Converts an OpenMP simd loop into LLVM IR using OpenMPIRBuilder.
///
/// Builds one canonical loop per loop in the wrapped omp.loop_nest, collapses
/// them into a single loop, and applies SIMD metadata (simdlen, safelen,
/// order, aligned, if) via OpenMPIRBuilder::applySimd.
static LogicalResult
convertOmpSimd(Operation &opInst, llvm::IRBuilderBase &builder,
               LLVM::ModuleTranslation &moduleTranslation) {
  auto simdOp = cast<omp::SimdOp>(opInst);
  auto loopOp = cast<omp::LoopNestOp>(simdOp.getWrappedLoop());

  // Bail out early on clauses this translation does not handle yet.
  if (failed(checkImplementationStatus(opInst)))
    return failure();

  // Gather private-clause block arguments, the corresponding MLIR values and
  // the privatizer declarations.
  MutableArrayRef<BlockArgument> privateBlockArgs =
      cast<omp::BlockArgOpenMPOpInterface>(*simdOp).getPrivateBlockArgs();
  SmallVector<mlir::Value> mlirPrivateVars;
  SmallVector<llvm::Value *> llvmPrivateVars;
  SmallVector<omp::PrivateClauseOp> privateDecls;
  mlirPrivateVars.reserve(privateBlockArgs.size());
  llvmPrivateVars.reserve(privateBlockArgs.size());
  collectPrivatizationDecls(simdOp, privateDecls);

  for (mlir::Value privateVar : simdOp.getPrivateVars())
    mlirPrivateVars.push_back(privateVar);

  llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
      findAllocaInsertPoint(builder, moduleTranslation);
  llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);

  llvm::Expected<llvm::BasicBlock *> afterAllocas = allocatePrivateVars(
      builder, moduleTranslation, privateBlockArgs, privateDecls,
      mlirPrivateVars, llvmPrivateVars, allocaIP);
  if (handleError(afterAllocas, opInst).failed())
    return failure();

  // Generator of the canonical loop body.
  SmallVector<llvm::CanonicalLoopInfo *> loopInfos;
  SmallVector<llvm::OpenMPIRBuilder::InsertPointTy> bodyInsertPoints;
  auto bodyGen = [&](llvm::OpenMPIRBuilder::InsertPointTy ip,
                     llvm::Value *iv) -> llvm::Error {
    // Make sure further conversions know about the induction variable.
    moduleTranslation.mapValue(
        loopOp.getRegion().front().getArgument(loopInfos.size()), iv);

    // Capture the body insertion point for use in nested loops. BodyIP of the
    // CanonicalLoopInfo always points to the beginning of the entry block of
    // the body.
    bodyInsertPoints.push_back(ip);

    // Only the innermost loop of the nest emits the region; invocations for
    // the outer loops just record their insertion point above.
    if (loopInfos.size() != loopOp.getNumLoops() - 1)
      return llvm::Error::success();

    // Convert the body of the loop.
    builder.restoreIP(ip);
    return convertOmpOpRegions(loopOp.getRegion(), "omp.simd.region", builder,
                               moduleTranslation)
        .takeError();
  };

  // Delegate actual loop construction to the OpenMP IRBuilder.
  // TODO: this currently assumes omp.loop_nest is semantically similar to SCF
  // loop, i.e. it has a positive step, uses signed integer semantics.
  // Reconsider this code when the nested loop operation clearly supports more
  // cases.
  llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
  for (unsigned i = 0, e = loopOp.getNumLoops(); i < e; ++i) {
    llvm::Value *lowerBound =
        moduleTranslation.lookupValue(loopOp.getLoopLowerBounds()[i]);
    llvm::Value *upperBound =
        moduleTranslation.lookupValue(loopOp.getLoopUpperBounds()[i]);
    llvm::Value *step = moduleTranslation.lookupValue(loopOp.getLoopSteps()[i]);

    // Make sure loop trip count are emitted in the preheader of the outermost
    // loop at the latest so that they are all available for the new collapsed
    // loop will be created below.
    llvm::OpenMPIRBuilder::LocationDescription loc = ompLoc;
    llvm::OpenMPIRBuilder::InsertPointTy computeIP = ompLoc.IP;
    if (i != 0) {
      // Inner loops are created at the body insertion point of the enclosing
      // loop; their trip counts are computed in the outermost preheader.
      loc = llvm::OpenMPIRBuilder::LocationDescription(bodyInsertPoints.back(),
                                                       ompLoc.DL);
      computeIP = loopInfos.front()->getPreheaderIP();
    }

    llvm::Expected<llvm::CanonicalLoopInfo *> loopResult =
        ompBuilder->createCanonicalLoop(
            loc, bodyGen, lowerBound, upperBound, step,
            /*IsSigned=*/true, /*InclusiveStop=*/true, computeIP);

    if (failed(handleError(loopResult, *loopOp)))
      return failure();

    loopInfos.push_back(*loopResult);
  }

  // Collapse loops.
  llvm::IRBuilderBase::InsertPoint afterIP = loopInfos.front()->getAfterIP();
  llvm::CanonicalLoopInfo *loopInfo =
      ompBuilder->collapseLoops(ompLoc.DL, loopInfos, {});

  llvm::ConstantInt *simdlen = nullptr;
  if (std::optional<uint64_t> simdlenVar = simdOp.getSimdlen())
    simdlen = builder.getInt64(simdlenVar.value());

  llvm::ConstantInt *safelen = nullptr;
  if (std::optional<uint64_t> safelenVar = simdOp.getSafelen())
    safelen = builder.getInt64(safelenVar.value());

  // Translate the aligned clause: load each aligned pointer before the loop
  // and record its alignment for applySimd.
  llvm::MapVector<llvm::Value *, llvm::Value *> alignedVars;
  llvm::omp::OrderKind order = convertOrderKind(simdOp.getOrder());
  llvm::BasicBlock *sourceBlock = builder.GetInsertBlock();
  std::optional<ArrayAttr> alignmentValues = simdOp.getAlignments();
  mlir::OperandRange operands = simdOp.getAlignedVars();
  // NOTE(review): assumes `alignments` is present and parallel to
  // `aligned_vars` whenever the latter is non-empty — presumably enforced by
  // the op verifier; confirm before relying on it.
  for (size_t i = 0; i < operands.size(); ++i) {
    llvm::Value *alignment = nullptr;
    llvm::Value *llvmVal = moduleTranslation.lookupValue(operands[i]);
    llvm::Type *ty = llvmVal->getType();
    if (auto intAttr = llvm::dyn_cast<IntegerAttr>((*alignmentValues)[i])) {
      alignment = builder.getInt64(intAttr.getInt());
      assert(ty->isPointerTy() && "Invalid type for aligned variable");
      assert(alignment && "Invalid alignment value");
      auto curInsert = builder.saveIP();
      builder.SetInsertPoint(sourceBlock->getTerminator());
      llvmVal = builder.CreateLoad(ty, llvmVal);
      builder.restoreIP(curInsert);
      alignedVars[llvmVal] = alignment;
    }
  }
  ompBuilder->applySimd(loopInfo, alignedVars,
                        simdOp.getIfExpr()
                            ? moduleTranslation.lookupValue(simdOp.getIfExpr())
                            : nullptr,
                        order, simdlen, safelen);

  builder.restoreIP(afterIP);

  return cleanupPrivateVars(builder, moduleTranslation, simdOp.getLoc(),
                            llvmPrivateVars, privateDecls);
}
2377 
2378 /// Convert an Atomic Ordering attribute to llvm::AtomicOrdering.
2379 static llvm::AtomicOrdering
2380 convertAtomicOrdering(std::optional<omp::ClauseMemoryOrderKind> ao) {
2381   if (!ao)
2382     return llvm::AtomicOrdering::Monotonic; // Default Memory Ordering
2383 
2384   switch (*ao) {
2385   case omp::ClauseMemoryOrderKind::Seq_cst:
2386     return llvm::AtomicOrdering::SequentiallyConsistent;
2387   case omp::ClauseMemoryOrderKind::Acq_rel:
2388     return llvm::AtomicOrdering::AcquireRelease;
2389   case omp::ClauseMemoryOrderKind::Acquire:
2390     return llvm::AtomicOrdering::Acquire;
2391   case omp::ClauseMemoryOrderKind::Release:
2392     return llvm::AtomicOrdering::Release;
2393   case omp::ClauseMemoryOrderKind::Relaxed:
2394     return llvm::AtomicOrdering::Monotonic;
2395   }
2396   llvm_unreachable("Unknown ClauseMemoryOrderKind kind");
2397 }
2398 
2399 /// Convert omp.atomic.read operation to LLVM IR.
2400 static LogicalResult
2401 convertOmpAtomicRead(Operation &opInst, llvm::IRBuilderBase &builder,
2402                      LLVM::ModuleTranslation &moduleTranslation) {
2403   auto readOp = cast<omp::AtomicReadOp>(opInst);
2404   if (failed(checkImplementationStatus(opInst)))
2405     return failure();
2406 
2407   llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
2408 
2409   llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
2410 
2411   llvm::AtomicOrdering AO = convertAtomicOrdering(readOp.getMemoryOrder());
2412   llvm::Value *x = moduleTranslation.lookupValue(readOp.getX());
2413   llvm::Value *v = moduleTranslation.lookupValue(readOp.getV());
2414 
2415   llvm::Type *elementType =
2416       moduleTranslation.convertType(readOp.getElementType());
2417 
2418   llvm::OpenMPIRBuilder::AtomicOpValue V = {v, elementType, false, false};
2419   llvm::OpenMPIRBuilder::AtomicOpValue X = {x, elementType, false, false};
2420   builder.restoreIP(ompBuilder->createAtomicRead(ompLoc, X, V, AO));
2421   return success();
2422 }
2423 
2424 /// Converts an omp.atomic.write operation to LLVM IR.
2425 static LogicalResult
2426 convertOmpAtomicWrite(Operation &opInst, llvm::IRBuilderBase &builder,
2427                       LLVM::ModuleTranslation &moduleTranslation) {
2428   auto writeOp = cast<omp::AtomicWriteOp>(opInst);
2429   if (failed(checkImplementationStatus(opInst)))
2430     return failure();
2431 
2432   llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
2433 
2434   llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
2435   llvm::AtomicOrdering ao = convertAtomicOrdering(writeOp.getMemoryOrder());
2436   llvm::Value *expr = moduleTranslation.lookupValue(writeOp.getExpr());
2437   llvm::Value *dest = moduleTranslation.lookupValue(writeOp.getX());
2438   llvm::Type *ty = moduleTranslation.convertType(writeOp.getExpr().getType());
2439   llvm::OpenMPIRBuilder::AtomicOpValue x = {dest, ty, /*isSigned=*/false,
2440                                             /*isVolatile=*/false};
2441   builder.restoreIP(ompBuilder->createAtomicWrite(ompLoc, x, expr, ao));
2442   return success();
2443 }
2444 
2445 /// Converts an LLVM dialect binary operation to the corresponding enum value
2446 /// for `atomicrmw` supported binary operation.
2447 llvm::AtomicRMWInst::BinOp convertBinOpToAtomic(Operation &op) {
2448   return llvm::TypeSwitch<Operation *, llvm::AtomicRMWInst::BinOp>(&op)
2449       .Case([&](LLVM::AddOp) { return llvm::AtomicRMWInst::BinOp::Add; })
2450       .Case([&](LLVM::SubOp) { return llvm::AtomicRMWInst::BinOp::Sub; })
2451       .Case([&](LLVM::AndOp) { return llvm::AtomicRMWInst::BinOp::And; })
2452       .Case([&](LLVM::OrOp) { return llvm::AtomicRMWInst::BinOp::Or; })
2453       .Case([&](LLVM::XOrOp) { return llvm::AtomicRMWInst::BinOp::Xor; })
2454       .Case([&](LLVM::UMaxOp) { return llvm::AtomicRMWInst::BinOp::UMax; })
2455       .Case([&](LLVM::UMinOp) { return llvm::AtomicRMWInst::BinOp::UMin; })
2456       .Case([&](LLVM::FAddOp) { return llvm::AtomicRMWInst::BinOp::FAdd; })
2457       .Case([&](LLVM::FSubOp) { return llvm::AtomicRMWInst::BinOp::FSub; })
2458       .Default(llvm::AtomicRMWInst::BinOp::BAD_BINOP);
2459 }
2460 
/// Converts an OpenMP atomic update operation using OpenMPIRBuilder.
///
/// When the update region is a single recognizable binary op, an `atomicrmw`
/// can be emitted; otherwise the builder falls back to a cmpxchg loop driven
/// by the `updateFn` callback that inlines the region.
static LogicalResult
convertOmpAtomicUpdate(omp::AtomicUpdateOp &opInst,
                       llvm::IRBuilderBase &builder,
                       LLVM::ModuleTranslation &moduleTranslation) {
  llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
  if (failed(checkImplementationStatus(*opInst)))
    return failure();

  // Convert values and types.
  auto &innerOpList = opInst.getRegion().front().getOperations();
  bool isXBinopExpr{false};
  llvm::AtomicRMWInst::BinOp binop;
  mlir::Value mlirExpr;
  llvm::Value *llvmExpr = nullptr;
  llvm::Value *llvmX = nullptr;
  llvm::Type *llvmXElementType = nullptr;
  if (innerOpList.size() == 2) {
    // The two operations here are the update and the terminator.
    // Since we can identify the update operation, there is a possibility
    // that we can generate the atomicrmw instruction.
    mlir::Operation &innerOp = *opInst.getRegion().front().begin();
    if (!llvm::is_contained(innerOp.getOperands(),
                            opInst.getRegion().getArgument(0))) {
      return opInst.emitError("no atomic update operation with region argument"
                              " as operand found inside atomic.update region");
    }
    binop = convertBinOpToAtomic(innerOp);
    // Record whether `x` is the LHS of the binary op; the other operand is
    // the update expression.
    isXBinopExpr = innerOp.getOperand(0) == opInst.getRegion().getArgument(0);
    mlirExpr = (isXBinopExpr ? innerOp.getOperand(1) : innerOp.getOperand(0));
    llvmExpr = moduleTranslation.lookupValue(mlirExpr);
  } else {
    // Since the update region includes more than one operation
    // we will resort to generating a cmpxchg loop.
    binop = llvm::AtomicRMWInst::BinOp::BAD_BINOP;
  }
  llvmX = moduleTranslation.lookupValue(opInst.getX());
  llvmXElementType = moduleTranslation.convertType(
      opInst.getRegion().getArgument(0).getType());
  llvm::OpenMPIRBuilder::AtomicOpValue llvmAtomicX = {llvmX, llvmXElementType,
                                                      /*isSigned=*/false,
                                                      /*isVolatile=*/false};

  llvm::AtomicOrdering atomicOrdering =
      convertAtomicOrdering(opInst.getMemoryOrder());

  // Generate update code.
  // Callback used by the IRBuilder for the cmpxchg-loop fallback: inlines the
  // op's region with the region argument mapped to `atomicx` and returns the
  // value yielded by the region's omp.yield terminator.
  auto updateFn =
      [&opInst, &moduleTranslation](
          llvm::Value *atomicx,
          llvm::IRBuilder<> &builder) -> llvm::Expected<llvm::Value *> {
    Block &bb = *opInst.getRegion().begin();
    moduleTranslation.mapValue(*opInst.getRegion().args_begin(), atomicx);
    moduleTranslation.mapBlock(&bb, builder.GetInsertBlock());
    if (failed(moduleTranslation.convertBlock(bb, true, builder)))
      return llvm::make_error<PreviouslyReportedError>();

    omp::YieldOp yieldop = dyn_cast<omp::YieldOp>(bb.getTerminator());
    assert(yieldop && yieldop.getResults().size() == 1 &&
           "terminator must be omp.yield op and it must have exactly one "
           "argument");
    return moduleTranslation.lookupValue(yieldop.getResults()[0]);
  };

  // Handle ambiguous alloca, if any.
  auto allocaIP = findAllocaInsertPoint(builder, moduleTranslation);
  llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
  llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
      ompBuilder->createAtomicUpdate(ompLoc, allocaIP, llvmAtomicX, llvmExpr,
                                     atomicOrdering, binop, updateFn,
                                     isXBinopExpr);

  if (failed(handleError(afterIP, *opInst)))
    return failure();

  builder.restoreIP(*afterIP);
  return success();
}
2539 
/// Converts an omp.atomic.capture construct to LLVM IR via
/// OpenMPIRBuilder::createAtomicCapture. The captured construct contains an
/// atomic read paired with either an atomic update or an atomic write
/// (asserted below).
static LogicalResult
convertOmpAtomicCapture(omp::AtomicCaptureOp atomicCaptureOp,
                        llvm::IRBuilderBase &builder,
                        LLVM::ModuleTranslation &moduleTranslation) {
  llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
  if (failed(checkImplementationStatus(*atomicCaptureOp)))
    return failure();

  mlir::Value mlirExpr;
  bool isXBinopExpr = false, isPostfixUpdate = false;
  llvm::AtomicRMWInst::BinOp binop = llvm::AtomicRMWInst::BinOp::BAD_BINOP;

  omp::AtomicUpdateOp atomicUpdateOp = atomicCaptureOp.getAtomicUpdateOp();
  omp::AtomicWriteOp atomicWriteOp = atomicCaptureOp.getAtomicWriteOp();

  assert((atomicUpdateOp || atomicWriteOp) &&
         "internal op must be an atomic.update or atomic.write op");

  if (atomicWriteOp) {
    isPostfixUpdate = true;
    mlirExpr = atomicWriteOp.getExpr();
  } else {
    // Postfix update means the read happens before the update, i.e. the
    // update op is the second op in the capture region.
    isPostfixUpdate = atomicCaptureOp.getSecondOp() ==
                      atomicCaptureOp.getAtomicUpdateOp().getOperation();
    auto &innerOpList = atomicUpdateOp.getRegion().front().getOperations();
    // Find the binary update operation that uses the region argument
    // and get the expression to update
    if (innerOpList.size() == 2) {
      mlir::Operation &innerOp = *atomicUpdateOp.getRegion().front().begin();
      if (!llvm::is_contained(innerOp.getOperands(),
                              atomicUpdateOp.getRegion().getArgument(0))) {
        return atomicUpdateOp.emitError(
            "no atomic update operation with region argument"
            " as operand found inside atomic.update region");
      }
      binop = convertBinOpToAtomic(innerOp);
      isXBinopExpr =
          innerOp.getOperand(0) == atomicUpdateOp.getRegion().getArgument(0);
      mlirExpr = (isXBinopExpr ? innerOp.getOperand(1) : innerOp.getOperand(0));
    } else {
      // More than one op in the update region: fall back to a cmpxchg loop.
      binop = llvm::AtomicRMWInst::BinOp::BAD_BINOP;
    }
  }

  llvm::Value *llvmExpr = moduleTranslation.lookupValue(mlirExpr);
  llvm::Value *llvmX =
      moduleTranslation.lookupValue(atomicCaptureOp.getAtomicReadOp().getX());
  llvm::Value *llvmV =
      moduleTranslation.lookupValue(atomicCaptureOp.getAtomicReadOp().getV());
  llvm::Type *llvmXElementType = moduleTranslation.convertType(
      atomicCaptureOp.getAtomicReadOp().getElementType());
  llvm::OpenMPIRBuilder::AtomicOpValue llvmAtomicX = {llvmX, llvmXElementType,
                                                      /*isSigned=*/false,
                                                      /*isVolatile=*/false};
  llvm::OpenMPIRBuilder::AtomicOpValue llvmAtomicV = {llvmV, llvmXElementType,
                                                      /*isSigned=*/false,
                                                      /*isVolatile=*/false};

  llvm::AtomicOrdering atomicOrdering =
      convertAtomicOrdering(atomicCaptureOp.getMemoryOrder());

  // Callback emitting the new value of `x`: for atomic.write it is simply the
  // written expression; for atomic.update it inlines the update region with
  // the region argument mapped to `atomicx` and returns the yielded value.
  auto updateFn =
      [&](llvm::Value *atomicx,
          llvm::IRBuilder<> &builder) -> llvm::Expected<llvm::Value *> {
    if (atomicWriteOp)
      return moduleTranslation.lookupValue(atomicWriteOp.getExpr());
    Block &bb = *atomicUpdateOp.getRegion().begin();
    moduleTranslation.mapValue(*atomicUpdateOp.getRegion().args_begin(),
                               atomicx);
    moduleTranslation.mapBlock(&bb, builder.GetInsertBlock());
    if (failed(moduleTranslation.convertBlock(bb, true, builder)))
      return llvm::make_error<PreviouslyReportedError>();

    omp::YieldOp yieldop = dyn_cast<omp::YieldOp>(bb.getTerminator());
    assert(yieldop && yieldop.getResults().size() == 1 &&
           "terminator must be omp.yield op and it must have exactly one "
           "argument");
    return moduleTranslation.lookupValue(yieldop.getResults()[0]);
  };

  // Handle ambiguous alloca, if any.
  auto allocaIP = findAllocaInsertPoint(builder, moduleTranslation);
  llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
  llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
      ompBuilder->createAtomicCapture(
          ompLoc, allocaIP, llvmAtomicX, llvmAtomicV, llvmExpr, atomicOrdering,
          binop, updateFn, atomicUpdateOp, isPostfixUpdate, isXBinopExpr);

  if (failed(handleError(afterIP, *atomicCaptureOp)))
    return failure();

  builder.restoreIP(*afterIP);
  return success();
}
2634 
2635 /// Converts an OpenMP Threadprivate operation into LLVM IR using
2636 /// OpenMPIRBuilder.
2637 static LogicalResult
2638 convertOmpThreadprivate(Operation &opInst, llvm::IRBuilderBase &builder,
2639                         LLVM::ModuleTranslation &moduleTranslation) {
2640   llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
2641   llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
2642   auto threadprivateOp = cast<omp::ThreadprivateOp>(opInst);
2643 
2644   if (failed(checkImplementationStatus(opInst)))
2645     return failure();
2646 
2647   Value symAddr = threadprivateOp.getSymAddr();
2648   auto *symOp = symAddr.getDefiningOp();
2649 
2650   if (auto asCast = dyn_cast<LLVM::AddrSpaceCastOp>(symOp))
2651     symOp = asCast.getOperand().getDefiningOp();
2652 
2653   if (!isa<LLVM::AddressOfOp>(symOp))
2654     return opInst.emitError("Addressing symbol not found");
2655   LLVM::AddressOfOp addressOfOp = dyn_cast<LLVM::AddressOfOp>(symOp);
2656 
2657   LLVM::GlobalOp global =
2658       addressOfOp.getGlobal(moduleTranslation.symbolTable());
2659   llvm::GlobalValue *globalValue = moduleTranslation.lookupGlobal(global);
2660 
2661   if (!ompBuilder->Config.isTargetDevice()) {
2662     llvm::Type *type = globalValue->getValueType();
2663     llvm::TypeSize typeSize =
2664         builder.GetInsertBlock()->getModule()->getDataLayout().getTypeStoreSize(
2665             type);
2666     llvm::ConstantInt *size = builder.getInt64(typeSize.getFixedValue());
2667     llvm::Value *callInst = ompBuilder->createCachedThreadPrivate(
2668         ompLoc, globalValue, size, global.getSymName() + ".cache");
2669     moduleTranslation.mapValue(opInst.getResult(0), callInst);
2670   } else {
2671     moduleTranslation.mapValue(opInst.getResult(0), globalValue);
2672   }
2673 
2674   return success();
2675 }
2676 
2677 static llvm::OffloadEntriesInfoManager::OMPTargetDeviceClauseKind
2678 convertToDeviceClauseKind(mlir::omp::DeclareTargetDeviceType deviceClause) {
2679   switch (deviceClause) {
2680   case mlir::omp::DeclareTargetDeviceType::host:
2681     return llvm::OffloadEntriesInfoManager::OMPTargetDeviceClauseHost;
2682     break;
2683   case mlir::omp::DeclareTargetDeviceType::nohost:
2684     return llvm::OffloadEntriesInfoManager::OMPTargetDeviceClauseNoHost;
2685     break;
2686   case mlir::omp::DeclareTargetDeviceType::any:
2687     return llvm::OffloadEntriesInfoManager::OMPTargetDeviceClauseAny;
2688     break;
2689   }
2690   llvm_unreachable("unhandled device clause");
2691 }
2692 
2693 static llvm::OffloadEntriesInfoManager::OMPTargetGlobalVarEntryKind
2694 convertToCaptureClauseKind(
2695     mlir::omp::DeclareTargetCaptureClause captureClause) {
2696   switch (captureClause) {
2697   case mlir::omp::DeclareTargetCaptureClause::to:
2698     return llvm::OffloadEntriesInfoManager::OMPTargetGlobalVarEntryTo;
2699   case mlir::omp::DeclareTargetCaptureClause::link:
2700     return llvm::OffloadEntriesInfoManager::OMPTargetGlobalVarEntryLink;
2701   case mlir::omp::DeclareTargetCaptureClause::enter:
2702     return llvm::OffloadEntriesInfoManager::OMPTargetGlobalVarEntryEnter;
2703   }
2704   llvm_unreachable("unhandled capture clause");
2705 }
2706 
2707 static llvm::SmallString<64>
2708 getDeclareTargetRefPtrSuffix(LLVM::GlobalOp globalOp,
2709                              llvm::OpenMPIRBuilder &ompBuilder) {
2710   llvm::SmallString<64> suffix;
2711   llvm::raw_svector_ostream os(suffix);
2712   if (globalOp.getVisibility() == mlir::SymbolTable::Visibility::Private) {
2713     auto loc = globalOp->getLoc()->findInstanceOf<FileLineColLoc>();
2714     auto fileInfoCallBack = [&loc]() {
2715       return std::pair<std::string, uint64_t>(
2716           llvm::StringRef(loc.getFilename()), loc.getLine());
2717     };
2718 
2719     os << llvm::format(
2720         "_%x", ompBuilder.getTargetEntryUniqueInfo(fileInfoCallBack).FileID);
2721   }
2722   os << "_decl_tgt_ref_ptr";
2723 
2724   return suffix;
2725 }
2726 
2727 static bool isDeclareTargetLink(mlir::Value value) {
2728   if (auto addressOfOp =
2729           llvm::dyn_cast_if_present<LLVM::AddressOfOp>(value.getDefiningOp())) {
2730     auto modOp = addressOfOp->getParentOfType<mlir::ModuleOp>();
2731     Operation *gOp = modOp.lookupSymbol(addressOfOp.getGlobalName());
2732     if (auto declareTargetGlobal =
2733             llvm::dyn_cast<mlir::omp::DeclareTargetInterface>(gOp))
2734       if (declareTargetGlobal.getDeclareTargetCaptureClause() ==
2735           mlir::omp::DeclareTargetCaptureClause::link)
2736         return true;
2737   }
2738   return false;
2739 }
2740 
2741 // Returns the reference pointer generated by the lowering of the declare target
2742 // operation in cases where the link clause is used or the to clause is used in
2743 // USM mode.
2744 static llvm::Value *
2745 getRefPtrIfDeclareTarget(mlir::Value value,
2746                          LLVM::ModuleTranslation &moduleTranslation) {
2747   llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
2748 
2749   // An easier way to do this may just be to keep track of any pointer
2750   // references and their mapping to their respective operation
2751   if (auto addressOfOp =
2752           llvm::dyn_cast_if_present<LLVM::AddressOfOp>(value.getDefiningOp())) {
2753     if (auto gOp = llvm::dyn_cast_or_null<LLVM::GlobalOp>(
2754             addressOfOp->getParentOfType<mlir::ModuleOp>().lookupSymbol(
2755                 addressOfOp.getGlobalName()))) {
2756 
2757       if (auto declareTargetGlobal =
2758               llvm::dyn_cast<mlir::omp::DeclareTargetInterface>(
2759                   gOp.getOperation())) {
2760 
2761         // In this case, we must utilise the reference pointer generated by the
2762         // declare target operation, similar to Clang
2763         if ((declareTargetGlobal.getDeclareTargetCaptureClause() ==
2764              mlir::omp::DeclareTargetCaptureClause::link) ||
2765             (declareTargetGlobal.getDeclareTargetCaptureClause() ==
2766                  mlir::omp::DeclareTargetCaptureClause::to &&
2767              ompBuilder->Config.hasRequiresUnifiedSharedMemory())) {
2768           llvm::SmallString<64> suffix =
2769               getDeclareTargetRefPtrSuffix(gOp, *ompBuilder);
2770 
2771           if (gOp.getSymName().contains(suffix))
2772             return moduleTranslation.getLLVMModule()->getNamedValue(
2773                 gOp.getSymName());
2774 
2775           return moduleTranslation.getLLVMModule()->getNamedValue(
2776               (gOp.getSymName().str() + suffix.str()).str());
2777         }
2778       }
2779     }
2780   }
2781 
2782   return nullptr;
2783 }
2784 
namespace {
// A small helper structure to contain data gathered
// for map lowering and coalesce it into one area,
// avoiding extra computations such as searches in the
// llvm module for lowered mapped variables or checking
// if something is declare target (and retrieving the
// value) more than necessary.
//
// All vectors here (and those inherited from MapInfosTy) are parallel
// arrays: entry i of each describes the same mapped variable.
struct MapInfoData : llvm::OpenMPIRBuilder::MapInfosTy {
  // True if the entry's variable is a declare-target global whose reference
  // pointer is used as the base pointer.
  llvm::SmallVector<bool, 4> IsDeclareTarget;
  // True if the entry appears in another map entry's members list.
  llvm::SmallVector<bool, 4> IsAMember;
  // Identify if mapping was added by mapClause or use_device clauses.
  llvm::SmallVector<bool, 4> IsAMapping;
  // The omp.map.info operation each entry was gathered from.
  llvm::SmallVector<mlir::Operation *, 4> MapClause;
  // The LLVM value the MLIR offload pointer lowered to.
  llvm::SmallVector<llvm::Value *, 4> OriginalValue;
  // Stripped off array/pointer to get the underlying
  // element type
  llvm::SmallVector<llvm::Type *, 4> BaseType;

  /// Append arrays in \a CurInfo.
  void append(MapInfoData &CurInfo) {
    IsDeclareTarget.append(CurInfo.IsDeclareTarget.begin(),
                           CurInfo.IsDeclareTarget.end());
    MapClause.append(CurInfo.MapClause.begin(), CurInfo.MapClause.end());
    OriginalValue.append(CurInfo.OriginalValue.begin(),
                         CurInfo.OriginalValue.end());
    BaseType.append(CurInfo.BaseType.begin(), CurInfo.BaseType.end());
    llvm::OpenMPIRBuilder::MapInfosTy::append(CurInfo);
  }
};
} // namespace
2815 
2816 uint64_t getArrayElementSizeInBits(LLVM::LLVMArrayType arrTy, DataLayout &dl) {
2817   if (auto nestedArrTy = llvm::dyn_cast_if_present<LLVM::LLVMArrayType>(
2818           arrTy.getElementType()))
2819     return getArrayElementSizeInBits(nestedArrTy, dl);
2820   return dl.getTypeSizeInBits(arrTy.getElementType());
2821 }
2822 
2823 // This function calculates the size to be offloaded for a specified type, given
2824 // its associated map clause (which can contain bounds information which affects
2825 // the total size), this size is calculated based on the underlying element type
2826 // e.g. given a 1-D array of ints, we will calculate the size from the integer
2827 // type * number of elements in the array. This size can be used in other
2828 // calculations but is ultimately used as an argument to the OpenMP runtimes
2829 // kernel argument structure which is generated through the combinedInfo data
2830 // structures.
2831 // This function is somewhat equivalent to Clang's getExprTypeSize inside of
2832 // CGOpenMPRuntime.cpp.
2833 llvm::Value *getSizeInBytes(DataLayout &dl, const mlir::Type &type,
2834                             Operation *clauseOp, llvm::Value *basePointer,
2835                             llvm::Type *baseType, llvm::IRBuilderBase &builder,
2836                             LLVM::ModuleTranslation &moduleTranslation) {
2837   if (auto memberClause =
2838           mlir::dyn_cast_if_present<mlir::omp::MapInfoOp>(clauseOp)) {
2839     // This calculates the size to transfer based on bounds and the underlying
2840     // element type, provided bounds have been specified (Fortran
2841     // pointers/allocatables/target and arrays that have sections specified fall
2842     // into this as well).
2843     if (!memberClause.getBounds().empty()) {
2844       llvm::Value *elementCount = builder.getInt64(1);
2845       for (auto bounds : memberClause.getBounds()) {
2846         if (auto boundOp = mlir::dyn_cast_if_present<mlir::omp::MapBoundsOp>(
2847                 bounds.getDefiningOp())) {
2848           // The below calculation for the size to be mapped calculated from the
2849           // map.info's bounds is: (elemCount * [UB - LB] + 1), later we
2850           // multiply by the underlying element types byte size to get the full
2851           // size to be offloaded based on the bounds
2852           elementCount = builder.CreateMul(
2853               elementCount,
2854               builder.CreateAdd(
2855                   builder.CreateSub(
2856                       moduleTranslation.lookupValue(boundOp.getUpperBound()),
2857                       moduleTranslation.lookupValue(boundOp.getLowerBound())),
2858                   builder.getInt64(1)));
2859         }
2860       }
2861 
2862       // utilising getTypeSizeInBits instead of getTypeSize as getTypeSize gives
2863       // the size in inconsistent byte or bit format.
2864       uint64_t underlyingTypeSzInBits = dl.getTypeSizeInBits(type);
2865       if (auto arrTy = llvm::dyn_cast_if_present<LLVM::LLVMArrayType>(type))
2866         underlyingTypeSzInBits = getArrayElementSizeInBits(arrTy, dl);
2867 
2868       // The size in bytes x number of elements, the sizeInBytes stored is
2869       // the underyling types size, e.g. if ptr<i32>, it'll be the i32's
2870       // size, so we do some on the fly runtime math to get the size in
2871       // bytes from the extent (ub - lb) * sizeInBytes. NOTE: This may need
2872       // some adjustment for members with more complex types.
2873       return builder.CreateMul(elementCount,
2874                                builder.getInt64(underlyingTypeSzInBits / 8));
2875     }
2876   }
2877 
2878   return builder.getInt64(dl.getTypeSizeInBits(type) / 8);
2879 }
2880 
/// Gathers, for every map/use_device operand, the information needed for map
/// lowering into the parallel arrays of \p mapData: the lowered LLVM value,
/// base pointer (declare-target reference pointer where applicable), lowered
/// base type, runtime size, map-type flags, debug name, device-pointer kind,
/// and member/mapping classification. Entries for map operands are appended
/// first, then entries for use_device_addr/use_device_ptr operands that are
/// not already covered by a map entry.
static void collectMapDataFromMapOperands(
    MapInfoData &mapData, SmallVectorImpl<Value> &mapVars,
    LLVM::ModuleTranslation &moduleTranslation, DataLayout &dl,
    llvm::IRBuilderBase &builder, const ArrayRef<Value> &useDevPtrOperands = {},
    const ArrayRef<Value> &useDevAddrOperands = {}) {
  auto checkIsAMember = [](const auto &mapVars, auto mapOp) {
    // Check if this is a member mapping and correctly assign that it is, if
    // it is a member of a larger object.
    // TODO: Need better handling of members, and distinguishing of members
    // that are implicitly allocated on device vs explicitly passed in as
    // arguments.
    // TODO: May require some further additions to support nested record
    // types, i.e. member maps that can have member maps.
    for (Value mapValue : mapVars) {
      auto map = cast<omp::MapInfoOp>(mapValue.getDefiningOp());
      for (auto member : map.getMembers())
        if (member == mapOp)
          return true;
    }
    return false;
  };

  // Process MapOperands
  for (Value mapValue : mapVars) {
    auto mapOp = cast<omp::MapInfoOp>(mapValue.getDefiningOp());
    // Prefer varPtrPtr (pointer-to-pointer) over varPtr when present.
    Value offloadPtr =
        mapOp.getVarPtrPtr() ? mapOp.getVarPtrPtr() : mapOp.getVarPtr();
    mapData.OriginalValue.push_back(moduleTranslation.lookupValue(offloadPtr));
    mapData.Pointers.push_back(mapData.OriginalValue.back());

    if (llvm::Value *refPtr =
            getRefPtrIfDeclareTarget(offloadPtr,
                                     moduleTranslation)) { // declare target
      mapData.IsDeclareTarget.push_back(true);
      mapData.BasePointers.push_back(refPtr);
    } else { // regular mapped variable
      mapData.IsDeclareTarget.push_back(false);
      mapData.BasePointers.push_back(mapData.OriginalValue.back());
    }

    mapData.BaseType.push_back(
        moduleTranslation.convertType(mapOp.getVarType()));
    mapData.Sizes.push_back(
        getSizeInBytes(dl, mapOp.getVarType(), mapOp, mapData.Pointers.back(),
                       mapData.BaseType.back(), builder, moduleTranslation));
    mapData.MapClause.push_back(mapOp.getOperation());
    mapData.Types.push_back(
        llvm::omp::OpenMPOffloadMappingFlags(mapOp.getMapType().value()));
    mapData.Names.push_back(LLVM::createMappingInformation(
        mapOp.getLoc(), *moduleTranslation.getOpenMPBuilder()));
    mapData.DevicePointers.push_back(llvm::OpenMPIRBuilder::DeviceInfoTy::None);
    mapData.IsAMapping.push_back(true);
    mapData.IsAMember.push_back(checkIsAMember(mapVars, mapOp));
  }

  // If a use_device operand's value already has a map entry, mark that entry
  // as a return parameter with the given device-info kind instead of adding
  // a duplicate entry; returns whether such an entry was found.
  auto findMapInfo = [&mapData](llvm::Value *val,
                                llvm::OpenMPIRBuilder::DeviceInfoTy devInfoTy) {
    unsigned index = 0;
    bool found = false;
    for (llvm::Value *basePtr : mapData.OriginalValue) {
      if (basePtr == val && mapData.IsAMapping[index]) {
        found = true;
        mapData.Types[index] |=
            llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_RETURN_PARAM;
        mapData.DevicePointers[index] = devInfoTy;
      }
      index++;
    }
    return found;
  };

  // Process useDevPtr(Addr)Operands
  auto addDevInfos = [&](const llvm::ArrayRef<Value> &useDevOperands,
                         llvm::OpenMPIRBuilder::DeviceInfoTy devInfoTy) {
    for (Value mapValue : useDevOperands) {
      auto mapOp = cast<omp::MapInfoOp>(mapValue.getDefiningOp());
      Value offloadPtr =
          mapOp.getVarPtrPtr() ? mapOp.getVarPtrPtr() : mapOp.getVarPtr();
      llvm::Value *origValue = moduleTranslation.lookupValue(offloadPtr);

      // Check if map info is already present for this entry.
      if (!findMapInfo(origValue, devInfoTy)) {
        // No existing map entry: append a fresh RETURN_PARAM-only entry with
        // zero size (no data transfer, only a device address/pointer result).
        mapData.OriginalValue.push_back(origValue);
        mapData.Pointers.push_back(mapData.OriginalValue.back());
        mapData.IsDeclareTarget.push_back(false);
        mapData.BasePointers.push_back(mapData.OriginalValue.back());
        mapData.BaseType.push_back(
            moduleTranslation.convertType(mapOp.getVarType()));
        mapData.Sizes.push_back(builder.getInt64(0));
        mapData.MapClause.push_back(mapOp.getOperation());
        mapData.Types.push_back(
            llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_RETURN_PARAM);
        mapData.Names.push_back(LLVM::createMappingInformation(
            mapOp.getLoc(), *moduleTranslation.getOpenMPBuilder()));
        mapData.DevicePointers.push_back(devInfoTy);
        mapData.IsAMapping.push_back(false);
        mapData.IsAMember.push_back(checkIsAMember(useDevOperands, mapOp));
      }
    }
  };

  addDevInfos(useDevAddrOperands, llvm::OpenMPIRBuilder::DeviceInfoTy::Address);
  addDevInfos(useDevPtrOperands, llvm::OpenMPIRBuilder::DeviceInfoTy::Pointer);
}
2985 
2986 static int getMapDataMemberIdx(MapInfoData &mapData, omp::MapInfoOp memberOp) {
2987   auto *res = llvm::find(mapData.MapClause, memberOp);
2988   assert(res != mapData.MapClause.end() &&
2989          "MapInfoOp for member not found in MapData, cannot return index");
2990   return std::distance(mapData.MapClause.begin(), res);
2991 }
2992 
2993 static omp::MapInfoOp getFirstOrLastMappedMemberPtr(omp::MapInfoOp mapInfo,
2994                                                     bool first) {
2995   ArrayAttr indexAttr = mapInfo.getMembersIndexAttr();
2996   // Only 1 member has been mapped, we can return it.
2997   if (indexAttr.size() == 1)
2998     return cast<omp::MapInfoOp>(mapInfo.getMembers()[0].getDefiningOp());
2999 
3000   llvm::SmallVector<size_t> indices(indexAttr.size());
3001   std::iota(indices.begin(), indices.end(), 0);
3002 
3003   llvm::sort(indices.begin(), indices.end(),
3004              [&](const size_t a, const size_t b) {
3005                auto memberIndicesA = cast<ArrayAttr>(indexAttr[a]);
3006                auto memberIndicesB = cast<ArrayAttr>(indexAttr[b]);
3007                for (const auto it : llvm::zip(memberIndicesA, memberIndicesB)) {
3008                  int64_t aIndex = cast<IntegerAttr>(std::get<0>(it)).getInt();
3009                  int64_t bIndex = cast<IntegerAttr>(std::get<1>(it)).getInt();
3010 
3011                  if (aIndex == bIndex)
3012                    continue;
3013 
3014                  if (aIndex < bIndex)
3015                    return first;
3016 
3017                  if (aIndex > bIndex)
3018                    return !first;
3019                }
3020 
3021                // Iterated the up until the end of the smallest member and
3022                // they were found to be equal up to that point, so select
3023                // the member with the lowest index count, so the "parent"
3024                return memberIndicesA.size() < memberIndicesB.size();
3025              });
3026 
3027   return llvm::cast<omp::MapInfoOp>(
3028       mapInfo.getMembers()[indices.front()].getDefiningOp());
3029 }
3030 
/// This function calculates the array/pointer offset for map data provided
/// with bounds operations, e.g. when provided something like the following:
///
/// Fortran
///     map(tofrom: array(2:5, 3:2))
///   or
/// C++
///   map(tofrom: array[1:4][2:3])
/// We must calculate the initial pointer offset to pass across, this function
/// performs this using bounds.
///
/// NOTE: which while specified in row-major order it currently needs to be
/// flipped for Fortran's column order array allocation and access (as
/// opposed to C++'s row-major, hence the backwards processing where order is
/// important). This is likely important to keep in mind for the future when
/// we incorporate a C++ frontend, both frontends will need to agree on the
/// ordering of generated bounds operations (one may have to flip them) to
/// make the below lowering frontend agnostic. The offload size
/// calculation may also have to be adjusted for C++.
///
/// Returns GEP indices: {0, lb_n, ..., lb_0} for array types, or a single
/// accumulated element offset for pointer types; empty if there are no
/// bounds.
std::vector<llvm::Value *>
calculateBoundsOffset(LLVM::ModuleTranslation &moduleTranslation,
                      llvm::IRBuilderBase &builder, bool isArrayTy,
                      OperandRange bounds) {
  std::vector<llvm::Value *> idx;
  // There's no bounds to calculate an offset from, we can safely
  // ignore and return no indices.
  if (bounds.empty())
    return idx;

  // If we have an array type, then we have its type so can treat it as a
  // normal GEP instruction where the bounds operations are simply indexes
  // into the array. We currently do reverse order of the bounds, which
  // I believe leans more towards Fortran's column-major in memory.
  if (isArrayTy) {
    // Leading 0 steps through the pointer to the array itself.
    idx.push_back(builder.getInt64(0));
    for (int i = bounds.size() - 1; i >= 0; --i) {
      if (auto boundOp = dyn_cast_if_present<omp::MapBoundsOp>(
              bounds[i].getDefiningOp())) {
        idx.push_back(moduleTranslation.lookupValue(boundOp.getLowerBound()));
      }
    }
  } else {
    // If we do not have an array type, but we have bounds, then we're dealing
    // with a pointer that's being treated like an array and we have the
    // underlying type e.g. an i32, or f64 etc, e.g. a fortran descriptor base
    // address (pointer pointing to the actual data) so we must calculate the
    // offset using a single index which the following two loops attempts to
    // compute.

    // Calculates the size offset we need to make per row e.g. first row or
    // column only needs to be offset by one, but the next would have to be
    // the previous row/column offset multiplied by the extent of current row.
    //
    // For example ([1][10][100]):
    //
    //  - First row/column we move by 1 for each index increment
    //  - Second row/column we move by 1 (first row/column) * 10 (extent/size of
    //  current) for 10 for each index increment
    //  - Third row/column we would move by 10 (second row/column) *
    //  (extent/size of current) 100 for 1000 for each index increment
    std::vector<llvm::Value *> dimensionIndexSizeOffset{builder.getInt64(1)};
    for (size_t i = 1; i < bounds.size(); ++i) {
      if (auto boundOp = dyn_cast_if_present<omp::MapBoundsOp>(
              bounds[i].getDefiningOp())) {
        dimensionIndexSizeOffset.push_back(builder.CreateMul(
            moduleTranslation.lookupValue(boundOp.getExtent()),
            dimensionIndexSizeOffset[i - 1]));
      }
    }

    // Now that we have calculated how much we move by per index, we must
    // multiply each lower bound offset in indexes by the size offset we
    // have calculated in the previous and accumulate the results to get
    // our final resulting offset.
    for (int i = bounds.size() - 1; i >= 0; --i) {
      if (auto boundOp = dyn_cast_if_present<omp::MapBoundsOp>(
              bounds[i].getDefiningOp())) {
        if (idx.empty())
          idx.emplace_back(builder.CreateMul(
              moduleTranslation.lookupValue(boundOp.getLowerBound()),
              dimensionIndexSizeOffset[i]));
        else
          idx.back() = builder.CreateAdd(
              idx.back(), builder.CreateMul(moduleTranslation.lookupValue(
                                                boundOp.getLowerBound()),
                                            dimensionIndexSizeOffset[i]));
      }
    }
  }

  return idx;
}
3123 
// This creates two insertions into the MapInfosTy data structure for the
// "parent" of a set of members, (usually a container e.g.
// class/structure/derived type) when subsequent members have also been
// explicitly mapped on the same map clause. Certain types, such as Fortran
// descriptors are mapped like this as well, however, the members are
// implicit as far as a user is concerned, but we must explicitly map them
// internally.
//
// This function also returns the memberOfFlag for this particular parent,
// which is utilised in subsequent member mappings (by modifying their map type
// with it) to indicate that a member is part of this parent and should be
// treated by the runtime as such. Important to achieve the correct mapping.
//
// This function borrows a lot from Clang's emitCombinedEntry function
// inside of CGOpenMPRuntime.cpp
static llvm::omp::OpenMPOffloadMappingFlags mapParentWithMembers(
    LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase &builder,
    llvm::OpenMPIRBuilder &ompBuilder, DataLayout &dl,
    llvm::OpenMPIRBuilder::MapInfosTy &combinedInfo, MapInfoData &mapData,
    uint64_t mapDataIndex, bool isTargetParams) {
  // Map the first segment of our structure
  combinedInfo.Types.emplace_back(
      isTargetParams
          ? llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TARGET_PARAM
          : llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_NONE);
  combinedInfo.DevicePointers.emplace_back(
      mapData.DevicePointers[mapDataIndex]);
  combinedInfo.Names.emplace_back(LLVM::createMappingInformation(
      mapData.MapClause[mapDataIndex]->getLoc(), ompBuilder));
  combinedInfo.BasePointers.emplace_back(mapData.BasePointers[mapDataIndex]);

  // Calculate size of the parent object being mapped based on the
  // addresses at runtime, highAddr - lowAddr = size. This of course
  // doesn't factor in allocated data like pointers, hence the further
  // processing of members specified by users, or in the case of
  // Fortran pointers and allocatables, the mapping of the pointed to
  // data by the descriptor (which itself, is a structure containing
  // runtime information on the dynamically allocated data).
  auto parentClause =
      llvm::cast<omp::MapInfoOp>(mapData.MapClause[mapDataIndex]);

  llvm::Value *lowAddr, *highAddr;
  if (!parentClause.getPartialMap()) {
    // Full map: the span is the whole parent object, [ptr, ptr + 1 element).
    lowAddr = builder.CreatePointerCast(mapData.Pointers[mapDataIndex],
                                        builder.getPtrTy());
    highAddr = builder.CreatePointerCast(
        builder.CreateConstGEP1_32(mapData.BaseType[mapDataIndex],
                                   mapData.Pointers[mapDataIndex], 1),
        builder.getPtrTy());
    combinedInfo.Pointers.emplace_back(mapData.Pointers[mapDataIndex]);
  } else {
    // Partial map: the span covers only the explicitly mapped members, from
    // the lowest-index member's start to one past the highest-index member.
    auto mapOp = dyn_cast<omp::MapInfoOp>(mapData.MapClause[mapDataIndex]);
    int firstMemberIdx = getMapDataMemberIdx(
        mapData, getFirstOrLastMappedMemberPtr(mapOp, true));
    lowAddr = builder.CreatePointerCast(mapData.Pointers[firstMemberIdx],
                                        builder.getPtrTy());
    int lastMemberIdx = getMapDataMemberIdx(
        mapData, getFirstOrLastMappedMemberPtr(mapOp, false));
    highAddr = builder.CreatePointerCast(
        builder.CreateGEP(mapData.BaseType[lastMemberIdx],
                          mapData.Pointers[lastMemberIdx], builder.getInt64(1)),
        builder.getPtrTy());
    combinedInfo.Pointers.emplace_back(mapData.Pointers[firstMemberIdx]);
  }

  llvm::Value *size = builder.CreateIntCast(
      builder.CreatePtrDiff(builder.getInt8Ty(), highAddr, lowAddr),
      builder.getInt64Ty(),
      /*isSigned=*/false);
  combinedInfo.Sizes.push_back(size);

  llvm::omp::OpenMPOffloadMappingFlags memberOfFlag =
      ompBuilder.getMemberOfFlag(combinedInfo.BasePointers.size() - 1);

  // This creates the initial MEMBER_OF mapping that consists of
  // the parent/top level container (same as above effectively, except
  // with a fixed initial compile time size and separate maptype which
  // indicates the true map type (tofrom etc.). This parent mapping is
  // only relevant if the structure in its totality is being mapped,
  // otherwise the above suffices.
  if (!parentClause.getPartialMap()) {
    // TODO: This will need to be expanded to include the whole host of logic
    // for the map flags that Clang currently supports (e.g. it should do some
    // further case specific flag modifications). For the moment, it handles
    // what we support as expected.
    llvm::omp::OpenMPOffloadMappingFlags mapFlag = mapData.Types[mapDataIndex];
    ompBuilder.setCorrectMemberOfFlag(mapFlag, memberOfFlag);
    combinedInfo.Types.emplace_back(mapFlag);
    combinedInfo.DevicePointers.emplace_back(
        llvm::OpenMPIRBuilder::DeviceInfoTy::None);
    combinedInfo.Names.emplace_back(LLVM::createMappingInformation(
        mapData.MapClause[mapDataIndex]->getLoc(), ompBuilder));
    combinedInfo.BasePointers.emplace_back(mapData.BasePointers[mapDataIndex]);
    combinedInfo.Pointers.emplace_back(mapData.Pointers[mapDataIndex]);
    combinedInfo.Sizes.emplace_back(mapData.Sizes[mapDataIndex]);
  }
  return memberOfFlag;
}
3222 
3223 // The intent is to verify if the mapped data being passed is a
3224 // pointer -> pointee that requires special handling in certain cases,
3225 // e.g. applying the OMP_MAP_PTR_AND_OBJ map type.
3226 //
3227 // There may be a better way to verify this, but unfortunately with
3228 // opaque pointers we lose the ability to easily check if something is
3229 // a pointer whilst maintaining access to the underlying type.
3230 static bool checkIfPointerMap(omp::MapInfoOp mapOp) {
3231   // If we have a varPtrPtr field assigned then the underlying type is a pointer
3232   if (mapOp.getVarPtrPtr())
3233     return true;
3234 
3235   // If the map data is declare target with a link clause, then it's represented
3236   // as a pointer when we lower it to LLVM-IR even if at the MLIR level it has
3237   // no relation to pointers.
3238   if (isDeclareTargetLink(mapOp.getVarPtr()))
3239     return true;
3240 
3241   return false;
3242 }
3243 
3244 // This function is intended to add explicit mappings of members
3245 static void processMapMembersWithParent(
3246     LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase &builder,
3247     llvm::OpenMPIRBuilder &ompBuilder, DataLayout &dl,
3248     llvm::OpenMPIRBuilder::MapInfosTy &combinedInfo, MapInfoData &mapData,
3249     uint64_t mapDataIndex, llvm::omp::OpenMPOffloadMappingFlags memberOfFlag) {
3250 
3251   auto parentClause =
3252       llvm::cast<omp::MapInfoOp>(mapData.MapClause[mapDataIndex]);
3253 
3254   for (auto mappedMembers : parentClause.getMembers()) {
3255     auto memberClause =
3256         llvm::cast<omp::MapInfoOp>(mappedMembers.getDefiningOp());
3257     int memberDataIdx = getMapDataMemberIdx(mapData, memberClause);
3258 
3259     assert(memberDataIdx >= 0 && "could not find mapped member of structure");
3260 
3261     // If we're currently mapping a pointer to a block of data, we must
3262     // initially map the pointer, and then attatch/bind the data with a
3263     // subsequent map to the pointer. This segment of code generates the
3264     // pointer mapping, which can in certain cases be optimised out as Clang
3265     // currently does in its lowering. However, for the moment we do not do so,
3266     // in part as we currently have substantially less information on the data
3267     // being mapped at this stage.
3268     if (checkIfPointerMap(memberClause)) {
3269       auto mapFlag = llvm::omp::OpenMPOffloadMappingFlags(
3270           memberClause.getMapType().value());
3271       mapFlag &= ~llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TARGET_PARAM;
3272       mapFlag |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_MEMBER_OF;
3273       ompBuilder.setCorrectMemberOfFlag(mapFlag, memberOfFlag);
3274       combinedInfo.Types.emplace_back(mapFlag);
3275       combinedInfo.DevicePointers.emplace_back(
3276           llvm::OpenMPIRBuilder::DeviceInfoTy::None);
3277       combinedInfo.Names.emplace_back(
3278           LLVM::createMappingInformation(memberClause.getLoc(), ompBuilder));
3279       combinedInfo.BasePointers.emplace_back(
3280           mapData.BasePointers[mapDataIndex]);
3281       combinedInfo.Pointers.emplace_back(mapData.BasePointers[memberDataIdx]);
3282       combinedInfo.Sizes.emplace_back(builder.getInt64(
3283           moduleTranslation.getLLVMModule()->getDataLayout().getPointerSize()));
3284     }
3285 
3286     // Same MemberOfFlag to indicate its link with parent and other members
3287     // of.
3288     auto mapFlag =
3289         llvm::omp::OpenMPOffloadMappingFlags(memberClause.getMapType().value());
3290     mapFlag &= ~llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TARGET_PARAM;
3291     mapFlag |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_MEMBER_OF;
3292     ompBuilder.setCorrectMemberOfFlag(mapFlag, memberOfFlag);
3293     if (checkIfPointerMap(memberClause))
3294       mapFlag |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_PTR_AND_OBJ;
3295 
3296     combinedInfo.Types.emplace_back(mapFlag);
3297     combinedInfo.DevicePointers.emplace_back(
3298         mapData.DevicePointers[memberDataIdx]);
3299     combinedInfo.Names.emplace_back(
3300         LLVM::createMappingInformation(memberClause.getLoc(), ompBuilder));
3301     uint64_t basePointerIndex =
3302         checkIfPointerMap(memberClause) ? memberDataIdx : mapDataIndex;
3303     combinedInfo.BasePointers.emplace_back(
3304         mapData.BasePointers[basePointerIndex]);
3305     combinedInfo.Pointers.emplace_back(mapData.Pointers[memberDataIdx]);
3306 
3307     llvm::Value *size = mapData.Sizes[memberDataIdx];
3308     if (checkIfPointerMap(memberClause)) {
3309       size = builder.CreateSelect(
3310           builder.CreateIsNull(mapData.Pointers[memberDataIdx]),
3311           builder.getInt64(0), size);
3312     }
3313 
3314     combinedInfo.Sizes.emplace_back(size);
3315   }
3316 }
3317 
3318 static void
3319 processIndividualMap(MapInfoData &mapData, size_t mapDataIdx,
3320                      llvm::OpenMPIRBuilder::MapInfosTy &combinedInfo,
3321                      bool isTargetParams, int mapDataParentIdx = -1) {
3322   // Declare Target Mappings are excluded from being marked as
3323   // OMP_MAP_TARGET_PARAM as they are not passed as parameters, they're
3324   // marked with OMP_MAP_PTR_AND_OBJ instead.
3325   auto mapFlag = mapData.Types[mapDataIdx];
3326   auto mapInfoOp = llvm::cast<omp::MapInfoOp>(mapData.MapClause[mapDataIdx]);
3327 
3328   bool isPtrTy = checkIfPointerMap(mapInfoOp);
3329   if (isPtrTy)
3330     mapFlag |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_PTR_AND_OBJ;
3331 
3332   if (isTargetParams && !mapData.IsDeclareTarget[mapDataIdx])
3333     mapFlag |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TARGET_PARAM;
3334 
3335   if (mapInfoOp.getMapCaptureType().value() ==
3336           omp::VariableCaptureKind::ByCopy &&
3337       !isPtrTy)
3338     mapFlag |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_LITERAL;
3339 
3340   // if we're provided a mapDataParentIdx, then the data being mapped is
3341   // part of a larger object (in a parent <-> member mapping) and in this
3342   // case our BasePointer should be the parent.
3343   if (mapDataParentIdx >= 0)
3344     combinedInfo.BasePointers.emplace_back(
3345         mapData.BasePointers[mapDataParentIdx]);
3346   else
3347     combinedInfo.BasePointers.emplace_back(mapData.BasePointers[mapDataIdx]);
3348 
3349   combinedInfo.Pointers.emplace_back(mapData.Pointers[mapDataIdx]);
3350   combinedInfo.DevicePointers.emplace_back(mapData.DevicePointers[mapDataIdx]);
3351   combinedInfo.Names.emplace_back(mapData.Names[mapDataIdx]);
3352   combinedInfo.Types.emplace_back(mapFlag);
3353   combinedInfo.Sizes.emplace_back(mapData.Sizes[mapDataIdx]);
3354 }
3355 
3356 static void processMapWithMembersOf(
3357     LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase &builder,
3358     llvm::OpenMPIRBuilder &ompBuilder, DataLayout &dl,
3359     llvm::OpenMPIRBuilder::MapInfosTy &combinedInfo, MapInfoData &mapData,
3360     uint64_t mapDataIndex, bool isTargetParams) {
3361   auto parentClause =
3362       llvm::cast<omp::MapInfoOp>(mapData.MapClause[mapDataIndex]);
3363 
3364   // If we have a partial map (no parent referenced in the map clauses of the
3365   // directive, only members) and only a single member, we do not need to bind
3366   // the map of the member to the parent, we can pass the member separately.
3367   if (parentClause.getMembers().size() == 1 && parentClause.getPartialMap()) {
3368     auto memberClause = llvm::cast<omp::MapInfoOp>(
3369         parentClause.getMembers()[0].getDefiningOp());
3370     int memberDataIdx = getMapDataMemberIdx(mapData, memberClause);
3371     // Note: Clang treats arrays with explicit bounds that fall into this
3372     // category as a parent with map case, however, it seems this isn't a
3373     // requirement, and processing them as an individual map is fine. So,
3374     // we will handle them as individual maps for the moment, as it's
3375     // difficult for us to check this as we always require bounds to be
3376     // specified currently and it's also marginally more optimal (single
3377     // map rather than two). The difference may come from the fact that
3378     // Clang maps array without bounds as pointers (which we do not
3379     // currently do), whereas we treat them as arrays in all cases
3380     // currently.
3381     processIndividualMap(mapData, memberDataIdx, combinedInfo, isTargetParams,
3382                          mapDataIndex);
3383     return;
3384   }
3385 
3386   llvm::omp::OpenMPOffloadMappingFlags memberOfParentFlag =
3387       mapParentWithMembers(moduleTranslation, builder, ompBuilder, dl,
3388                            combinedInfo, mapData, mapDataIndex, isTargetParams);
3389   processMapMembersWithParent(moduleTranslation, builder, ompBuilder, dl,
3390                               combinedInfo, mapData, mapDataIndex,
3391                               memberOfParentFlag);
3392 }
3393 
3394 // This is a variation on Clang's GenerateOpenMPCapturedVars, which
3395 // generates different operation (e.g. load/store) combinations for
3396 // arguments to the kernel, based on map capture kinds which are then
3397 // utilised in the combinedInfo in place of the original Map value.
3398 static void
3399 createAlteredByCaptureMap(MapInfoData &mapData,
3400                           LLVM::ModuleTranslation &moduleTranslation,
3401                           llvm::IRBuilderBase &builder) {
3402   for (size_t i = 0; i < mapData.MapClause.size(); ++i) {
3403     // if it's declare target, skip it, it's handled separately.
3404     if (!mapData.IsDeclareTarget[i]) {
3405       auto mapOp = cast<omp::MapInfoOp>(mapData.MapClause[i]);
3406       omp::VariableCaptureKind captureKind =
3407           mapOp.getMapCaptureType().value_or(omp::VariableCaptureKind::ByRef);
3408       bool isPtrTy = checkIfPointerMap(mapOp);
3409 
3410       // Currently handles array sectioning lowerbound case, but more
3411       // logic may be required in the future. Clang invokes EmitLValue,
3412       // which has specialised logic for special Clang types such as user
3413       // defines, so it is possible we will have to extend this for
3414       // structures or other complex types. As the general idea is that this
3415       // function mimics some of the logic from Clang that we require for
3416       // kernel argument passing from host -> device.
3417       switch (captureKind) {
3418       case omp::VariableCaptureKind::ByRef: {
3419         llvm::Value *newV = mapData.Pointers[i];
3420         std::vector<llvm::Value *> offsetIdx = calculateBoundsOffset(
3421             moduleTranslation, builder, mapData.BaseType[i]->isArrayTy(),
3422             mapOp.getBounds());
3423         if (isPtrTy)
3424           newV = builder.CreateLoad(builder.getPtrTy(), newV);
3425 
3426         if (!offsetIdx.empty())
3427           newV = builder.CreateInBoundsGEP(mapData.BaseType[i], newV, offsetIdx,
3428                                            "array_offset");
3429         mapData.Pointers[i] = newV;
3430       } break;
3431       case omp::VariableCaptureKind::ByCopy: {
3432         llvm::Type *type = mapData.BaseType[i];
3433         llvm::Value *newV;
3434         if (mapData.Pointers[i]->getType()->isPointerTy())
3435           newV = builder.CreateLoad(type, mapData.Pointers[i]);
3436         else
3437           newV = mapData.Pointers[i];
3438 
3439         if (!isPtrTy) {
3440           auto curInsert = builder.saveIP();
3441           builder.restoreIP(findAllocaInsertPoint(builder, moduleTranslation));
3442           auto *memTempAlloc =
3443               builder.CreateAlloca(builder.getPtrTy(), nullptr, ".casted");
3444           builder.restoreIP(curInsert);
3445 
3446           builder.CreateStore(newV, memTempAlloc);
3447           newV = builder.CreateLoad(builder.getPtrTy(), memTempAlloc);
3448         }
3449 
3450         mapData.Pointers[i] = newV;
3451         mapData.BasePointers[i] = newV;
3452       } break;
3453       case omp::VariableCaptureKind::This:
3454       case omp::VariableCaptureKind::VLAType:
3455         mapData.MapClause[i]->emitOpError("Unhandled capture kind");
3456         break;
3457       }
3458     }
3459   }
3460 }
3461 
3462 // Generate all map related information and fill the combinedInfo.
3463 static void genMapInfos(llvm::IRBuilderBase &builder,
3464                         LLVM::ModuleTranslation &moduleTranslation,
3465                         DataLayout &dl,
3466                         llvm::OpenMPIRBuilder::MapInfosTy &combinedInfo,
3467                         MapInfoData &mapData, bool isTargetParams = false) {
3468   // We wish to modify some of the methods in which arguments are
3469   // passed based on their capture type by the target region, this can
3470   // involve generating new loads and stores, which changes the
3471   // MLIR value to LLVM value mapping, however, we only wish to do this
3472   // locally for the current function/target and also avoid altering
3473   // ModuleTranslation, so we remap the base pointer or pointer stored
3474   // in the map infos corresponding MapInfoData, which is later accessed
3475   // by genMapInfos and createTarget to help generate the kernel and
3476   // kernel arg structure. It primarily becomes relevant in cases like
3477   // bycopy, or byref range'd arrays. In the default case, we simply
3478   // pass thee pointer byref as both basePointer and pointer.
3479   if (!moduleTranslation.getOpenMPBuilder()->Config.isTargetDevice())
3480     createAlteredByCaptureMap(mapData, moduleTranslation, builder);
3481 
3482   llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
3483 
3484   // We operate under the assumption that all vectors that are
3485   // required in MapInfoData are of equal lengths (either filled with
3486   // default constructed data or appropiate information) so we can
3487   // utilise the size from any component of MapInfoData, if we can't
3488   // something is missing from the initial MapInfoData construction.
3489   for (size_t i = 0; i < mapData.MapClause.size(); ++i) {
3490     // NOTE/TODO: We currently do not support arbitrary depth record
3491     // type mapping.
3492     if (mapData.IsAMember[i])
3493       continue;
3494 
3495     auto mapInfoOp = dyn_cast<omp::MapInfoOp>(mapData.MapClause[i]);
3496     if (!mapInfoOp.getMembers().empty()) {
3497       processMapWithMembersOf(moduleTranslation, builder, *ompBuilder, dl,
3498                               combinedInfo, mapData, i, isTargetParams);
3499       continue;
3500     }
3501 
3502     processIndividualMap(mapData, i, combinedInfo, isTargetParams);
3503   }
3504 }
3505 
/// Shared lowering for `omp.target_data`, `omp.target_enter_data`,
/// `omp.target_exit_data` and `omp.target_update`: collects the clause
/// values of whichever operation is being translated, builds the map
/// information, and delegates to OpenMPIRBuilder::createTargetData to emit
/// the runtime calls (and, for `omp.target_data` only, the inlined body
/// region).
static LogicalResult
convertOmpTargetData(Operation *op, llvm::IRBuilderBase &builder,
                     LLVM::ModuleTranslation &moduleTranslation) {
  llvm::Value *ifCond = nullptr;
  int64_t deviceID = llvm::omp::OMP_DEVICEID_UNDEF;
  SmallVector<Value> mapVars;
  SmallVector<Value> useDevicePtrVars;
  SmallVector<Value> useDeviceAddrVars;
  // Runtime entry point for the standalone forms (enter/exit/update);
  // omp.target_data passes a body callback instead and leaves this unused.
  llvm::omp::RuntimeFunction RTLFn;
  DataLayout DL = DataLayout(op->getParentOfType<ModuleOp>());

  llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
  llvm::OpenMPIRBuilder::TargetDataInfo info(/*RequiresDevicePointerInfo=*/true,
                                             /*SeparateBeginEndCalls=*/true);

  // Gather the if-condition, device id and map-like operands from the
  // specific operation. Only compile-time constant `device` clauses are
  // extracted here; otherwise deviceID stays OMP_DEVICEID_UNDEF.
  LogicalResult result =
      llvm::TypeSwitch<Operation *, LogicalResult>(op)
          .Case([&](omp::TargetDataOp dataOp) {
            if (failed(checkImplementationStatus(*dataOp)))
              return failure();

            if (auto ifVar = dataOp.getIfExpr())
              ifCond = moduleTranslation.lookupValue(ifVar);

            if (auto devId = dataOp.getDevice())
              if (auto constOp =
                      dyn_cast<LLVM::ConstantOp>(devId.getDefiningOp()))
                if (auto intAttr = dyn_cast<IntegerAttr>(constOp.getValue()))
                  deviceID = intAttr.getInt();

            mapVars = dataOp.getMapVars();
            useDevicePtrVars = dataOp.getUseDevicePtrVars();
            useDeviceAddrVars = dataOp.getUseDeviceAddrVars();
            return success();
          })
          .Case([&](omp::TargetEnterDataOp enterDataOp) -> LogicalResult {
            if (failed(checkImplementationStatus(*enterDataOp)))
              return failure();

            if (auto ifVar = enterDataOp.getIfExpr())
              ifCond = moduleTranslation.lookupValue(ifVar);

            if (auto devId = enterDataOp.getDevice())
              if (auto constOp =
                      dyn_cast<LLVM::ConstantOp>(devId.getDefiningOp()))
                if (auto intAttr = dyn_cast<IntegerAttr>(constOp.getValue()))
                  deviceID = intAttr.getInt();
            RTLFn =
                enterDataOp.getNowait()
                    ? llvm::omp::OMPRTL___tgt_target_data_begin_nowait_mapper
                    : llvm::omp::OMPRTL___tgt_target_data_begin_mapper;
            mapVars = enterDataOp.getMapVars();
            info.HasNoWait = enterDataOp.getNowait();
            return success();
          })
          .Case([&](omp::TargetExitDataOp exitDataOp) -> LogicalResult {
            if (failed(checkImplementationStatus(*exitDataOp)))
              return failure();

            if (auto ifVar = exitDataOp.getIfExpr())
              ifCond = moduleTranslation.lookupValue(ifVar);

            if (auto devId = exitDataOp.getDevice())
              if (auto constOp =
                      dyn_cast<LLVM::ConstantOp>(devId.getDefiningOp()))
                if (auto intAttr = dyn_cast<IntegerAttr>(constOp.getValue()))
                  deviceID = intAttr.getInt();

            RTLFn = exitDataOp.getNowait()
                        ? llvm::omp::OMPRTL___tgt_target_data_end_nowait_mapper
                        : llvm::omp::OMPRTL___tgt_target_data_end_mapper;
            mapVars = exitDataOp.getMapVars();
            info.HasNoWait = exitDataOp.getNowait();
            return success();
          })
          .Case([&](omp::TargetUpdateOp updateDataOp) -> LogicalResult {
            if (failed(checkImplementationStatus(*updateDataOp)))
              return failure();

            if (auto ifVar = updateDataOp.getIfExpr())
              ifCond = moduleTranslation.lookupValue(ifVar);

            if (auto devId = updateDataOp.getDevice())
              if (auto constOp =
                      dyn_cast<LLVM::ConstantOp>(devId.getDefiningOp()))
                if (auto intAttr = dyn_cast<IntegerAttr>(constOp.getValue()))
                  deviceID = intAttr.getInt();

            RTLFn =
                updateDataOp.getNowait()
                    ? llvm::omp::OMPRTL___tgt_target_data_update_nowait_mapper
                    : llvm::omp::OMPRTL___tgt_target_data_update_mapper;
            mapVars = updateDataOp.getMapVars();
            info.HasNoWait = updateDataOp.getNowait();
            return success();
          })
          .Default([&](Operation *op) {
            llvm_unreachable("unexpected operation");
            return failure();
          });

  if (failed(result))
    return failure();

  using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;

  MapInfoData mapData;
  collectMapDataFromMapOperands(mapData, mapVars, moduleTranslation, DL,
                                builder, useDevicePtrVars, useDeviceAddrVars);

  // Fill up the arrays with all the mapped variables.
  llvm::OpenMPIRBuilder::MapInfosTy combinedInfo;
  auto genMapInfoCB =
      [&](InsertPointTy codeGenIP) -> llvm::OpenMPIRBuilder::MapInfosTy & {
    builder.restoreIP(codeGenIP);
    genMapInfos(builder, moduleTranslation, DL, combinedInfo, mapData);
    return combinedInfo;
  };

  // Define a lambda to apply mappings between use_device_addr and
  // use_device_ptr base pointers, and their associated block arguments.
  // For each use_device_* variable, it searches MapInfoData for the entry
  // with a matching base pointer and device-pointer kind, then maps the
  // block argument to that entry's base pointer (optionally transformed by
  // `mapper`, e.g. to load the privatized device pointer).
  auto mapUseDevice =
      [&moduleTranslation](
          llvm::OpenMPIRBuilder::DeviceInfoTy type,
          llvm::ArrayRef<BlockArgument> blockArgs,
          llvm::SmallVectorImpl<Value> &useDeviceVars, MapInfoData &mapInfoData,
          llvm::function_ref<llvm::Value *(llvm::Value *)> mapper = nullptr) {
        for (auto [arg, useDevVar] :
             llvm::zip_equal(blockArgs, useDeviceVars)) {

          // The map's base pointer is var_ptr_ptr when present (member-of
          // style maps), otherwise var_ptr.
          auto getMapBasePtr = [](omp::MapInfoOp mapInfoOp) {
            return mapInfoOp.getVarPtrPtr() ? mapInfoOp.getVarPtrPtr()
                                            : mapInfoOp.getVarPtr();
          };

          auto useDevMap = cast<omp::MapInfoOp>(useDevVar.getDefiningOp());
          for (auto [mapClause, devicePointer, basePointer] : llvm::zip_equal(
                   mapInfoData.MapClause, mapInfoData.DevicePointers,
                   mapInfoData.BasePointers)) {
            auto mapOp = cast<omp::MapInfoOp>(mapClause);
            if (getMapBasePtr(mapOp) != getMapBasePtr(useDevMap) ||
                devicePointer != type)
              continue;

            if (llvm::Value *devPtrInfoMap =
                    mapper ? mapper(basePointer) : basePointer) {
              moduleTranslation.mapValue(arg, devPtrInfoMap);
              break;
            }
          }
        }
      };

  using BodyGenTy = llvm::OpenMPIRBuilder::BodyGenTy;
  // Body generation callback, invoked by createTargetData for the
  // omp.target_data form only (asserted below).
  auto bodyGenCB = [&](InsertPointTy codeGenIP, BodyGenTy bodyGenType)
      -> llvm::OpenMPIRBuilder::InsertPointOrErrorTy {
    assert(isa<omp::TargetDataOp>(op) &&
           "BodyGen requested for non TargetDataOp");
    auto blockArgIface = cast<omp::BlockArgOpenMPOpInterface>(op);
    Region &region = cast<omp::TargetDataOp>(op).getRegion();
    switch (bodyGenType) {
    case BodyGenTy::Priv:
      // Check if any device ptr/addr info is available
      if (!info.DevicePtrInfoMap.empty()) {
        builder.restoreIP(codeGenIP);

        // use_device_addr arguments are loaded from the privatized
        // pointer slot; a null slot means no mapping is established.
        mapUseDevice(llvm::OpenMPIRBuilder::DeviceInfoTy::Address,
                     blockArgIface.getUseDeviceAddrBlockArgs(),
                     useDeviceAddrVars, mapData,
                     [&](llvm::Value *basePointer) -> llvm::Value * {
                       if (!info.DevicePtrInfoMap[basePointer].second)
                         return nullptr;
                       return builder.CreateLoad(
                           builder.getPtrTy(),
                           info.DevicePtrInfoMap[basePointer].second);
                     });
        // use_device_ptr arguments use the privatized slot directly.
        mapUseDevice(llvm::OpenMPIRBuilder::DeviceInfoTy::Pointer,
                     blockArgIface.getUseDevicePtrBlockArgs(), useDevicePtrVars,
                     mapData, [&](llvm::Value *basePointer) {
                       return info.DevicePtrInfoMap[basePointer].second;
                     });

        if (failed(inlineConvertOmpRegions(region, "omp.data.region", builder,
                                           moduleTranslation)))
          return llvm::make_error<PreviouslyReportedError>();
      }
      break;
    case BodyGenTy::DupNoPriv:
      break;
    case BodyGenTy::NoPriv:
      // If device info is available then region has already been generated
      if (info.DevicePtrInfoMap.empty()) {
        builder.restoreIP(codeGenIP);
        // For device pass, if use_device_ptr(addr) mappings were present,
        // we need to link them here before codegen.
        if (ompBuilder->Config.IsTargetDevice.value_or(false)) {
          mapUseDevice(llvm::OpenMPIRBuilder::DeviceInfoTy::Address,
                       blockArgIface.getUseDeviceAddrBlockArgs(),
                       useDeviceAddrVars, mapData);
          mapUseDevice(llvm::OpenMPIRBuilder::DeviceInfoTy::Pointer,
                       blockArgIface.getUseDevicePtrBlockArgs(),
                       useDevicePtrVars, mapData);
        }

        if (failed(inlineConvertOmpRegions(region, "omp.data.region", builder,
                                           moduleTranslation)))
          return llvm::make_error<PreviouslyReportedError>();
      }
      break;
    }
    return builder.saveIP();
  };

  llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
  llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
      findAllocaInsertPoint(builder, moduleTranslation);
  // omp.target_data takes the body callback; the standalone forms take the
  // runtime function selected in the TypeSwitch above.
  llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP = [&]() {
    if (isa<omp::TargetDataOp>(op))
      return ompBuilder->createTargetData(
          ompLoc, allocaIP, builder.saveIP(), builder.getInt64(deviceID),
          ifCond, info, genMapInfoCB, nullptr, bodyGenCB);
    return ompBuilder->createTargetData(ompLoc, allocaIP, builder.saveIP(),
                                        builder.getInt64(deviceID), ifCond,
                                        info, genMapInfoCB, &RTLFn);
  }();

  if (failed(handleError(afterIP, *op)))
    return failure();

  builder.restoreIP(*afterIP);
  return success();
}
3738 
3739 /// Lowers the FlagsAttr which is applied to the module on the device
3740 /// pass when offloading, this attribute contains OpenMP RTL globals that can
3741 /// be passed as flags to the frontend, otherwise they are set to default
3742 LogicalResult convertFlagsAttr(Operation *op, mlir::omp::FlagsAttr attribute,
3743                                LLVM::ModuleTranslation &moduleTranslation) {
3744   if (!cast<mlir::ModuleOp>(op))
3745     return failure();
3746 
3747   llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
3748 
3749   ompBuilder->M.addModuleFlag(llvm::Module::Max, "openmp-device",
3750                               attribute.getOpenmpDeviceVersion());
3751 
3752   if (attribute.getNoGpuLib())
3753     return success();
3754 
3755   ompBuilder->createGlobalFlag(
3756       attribute.getDebugKind() /*LangOpts().OpenMPTargetDebug*/,
3757       "__omp_rtl_debug_kind");
3758   ompBuilder->createGlobalFlag(
3759       attribute
3760           .getAssumeTeamsOversubscription() /*LangOpts().OpenMPTeamSubscription*/
3761       ,
3762       "__omp_rtl_assume_teams_oversubscription");
3763   ompBuilder->createGlobalFlag(
3764       attribute
3765           .getAssumeThreadsOversubscription() /*LangOpts().OpenMPThreadSubscription*/
3766       ,
3767       "__omp_rtl_assume_threads_oversubscription");
3768   ompBuilder->createGlobalFlag(
3769       attribute.getAssumeNoThreadState() /*LangOpts().OpenMPNoThreadState*/,
3770       "__omp_rtl_assume_no_thread_state");
3771   ompBuilder->createGlobalFlag(
3772       attribute
3773           .getAssumeNoNestedParallelism() /*LangOpts().OpenMPNoNestedParallelism*/
3774       ,
3775       "__omp_rtl_assume_no_nested_parallelism");
3776   return success();
3777 }
3778 
3779 static bool getTargetEntryUniqueInfo(llvm::TargetRegionEntryInfo &targetInfo,
3780                                      omp::TargetOp targetOp,
3781                                      llvm::StringRef parentName = "") {
3782   auto fileLoc = targetOp.getLoc()->findInstanceOf<FileLineColLoc>();
3783 
3784   assert(fileLoc && "No file found from location");
3785   StringRef fileName = fileLoc.getFilename().getValue();
3786 
3787   llvm::sys::fs::UniqueID id;
3788   if (auto ec = llvm::sys::fs::getUniqueID(fileName, id)) {
3789     targetOp.emitError("Unable to get unique ID for file");
3790     return false;
3791   }
3792 
3793   uint64_t line = fileLoc.getLine();
3794   targetInfo = llvm::TargetRegionEntryInfo(parentName, id.getDevice(),
3795                                            id.getFile(), line);
3796   return true;
3797 }
3798 
3799 static void
3800 handleDeclareTargetMapVar(MapInfoData &mapData,
3801                           LLVM::ModuleTranslation &moduleTranslation,
3802                           llvm::IRBuilderBase &builder, llvm::Function *func) {
3803   for (size_t i = 0; i < mapData.MapClause.size(); ++i) {
3804     // In the case of declare target mapped variables, the basePointer is
3805     // the reference pointer generated by the convertDeclareTargetAttr
3806     // method. Whereas the kernelValue is the original variable, so for
3807     // the device we must replace all uses of this original global variable
3808     // (stored in kernelValue) with the reference pointer (stored in
3809     // basePointer for declare target mapped variables), as for device the
3810     // data is mapped into this reference pointer and should be loaded
3811     // from it, the original variable is discarded. On host both exist and
3812     // metadata is generated (elsewhere in the convertDeclareTargetAttr)
3813     // function to link the two variables in the runtime and then both the
3814     // reference pointer and the pointer are assigned in the kernel argument
3815     // structure for the host.
3816     if (mapData.IsDeclareTarget[i]) {
3817       // If the original map value is a constant, then we have to make sure all
3818       // of it's uses within the current kernel/function that we are going to
3819       // rewrite are converted to instructions, as we will be altering the old
3820       // use (OriginalValue) from a constant to an instruction, which will be
3821       // illegal and ICE the compiler if the user is a constant expression of
3822       // some kind e.g. a constant GEP.
3823       if (auto *constant = dyn_cast<llvm::Constant>(mapData.OriginalValue[i]))
3824         convertUsersOfConstantsToInstructions(constant, func, false);
3825 
3826       // The users iterator will get invalidated if we modify an element,
3827       // so we populate this vector of uses to alter each user on an
3828       // individual basis to emit its own load (rather than one load for
3829       // all).
3830       llvm::SmallVector<llvm::User *> userVec;
3831       for (llvm::User *user : mapData.OriginalValue[i]->users())
3832         userVec.push_back(user);
3833 
3834       for (llvm::User *user : userVec) {
3835         if (auto *insn = dyn_cast<llvm::Instruction>(user)) {
3836           if (insn->getFunction() == func) {
3837             auto *load = builder.CreateLoad(mapData.BasePointers[i]->getType(),
3838                                             mapData.BasePointers[i]);
3839             load->moveBefore(insn->getIterator());
3840             user->replaceUsesOfWith(mapData.OriginalValue[i], load);
3841           }
3842         }
3843       }
3844     }
3845   }
3846 }
3847 
// The createDeviceArgumentAccessor function generates
// instructions for retrieving (accessing) kernel
// arguments inside of the device kernel for use by
// the kernel. This enables different semantics such as
// the creation of temporary copies of data allowing
// semantics like read-only/no host write back kernel
// arguments.
//
// This currently implements a very light version of Clang's
// EmitParmDecl's handling of direct argument handling as well
// as a portion of the argument access generation based on
// capture types found at the end of emitOutlinedFunctionPrologue
// in Clang. The indirect path handling of EmitParmDecl's may be
// required for future work, but a direct 1-to-1 copy doesn't seem
// possible as the logic is rather scattered throughout Clang's
// lowering and perhaps we wish to deviate slightly.
//
// \param mapData - A container containing vectors of information
// corresponding to the input argument, which should have a
// corresponding entry in the MapInfoData containers
// OriginalValue's.
// \param arg - This is the generated kernel function argument that
// corresponds to the passed in input argument. We generate different
// accesses of this Argument, based on capture type and other Input
// related information.
// \param input - This is the host side value that will be passed to
// the kernel i.e. the kernel input, we rewrite all uses of this within
// the kernel (as we generate the kernel body based on the target's region
// which maintains references to the original input) to the retVal argument
// upon exit of this function inside of the OMPIRBuilder. This interlinks
// the kernel argument to future uses of it in the function providing
// appropriate "glue" instructions in between.
// \param retVal - This is the value that all uses of input inside of the
// kernel will be re-written to, the goal of this function is to generate
// an appropriate location for the kernel argument to be accessed from,
// e.g. ByRef will result in a temporary allocation location and then
// a store of the kernel argument into this allocated memory which
// will then be loaded from, ByCopy will use the allocated memory
// directly.
static llvm::IRBuilderBase::InsertPoint
createDeviceArgumentAccessor(MapInfoData &mapData, llvm::Argument &arg,
                             llvm::Value *input, llvm::Value *&retVal,
                             llvm::IRBuilderBase &builder,
                             llvm::OpenMPIRBuilder &ompBuilder,
                             LLVM::ModuleTranslation &moduleTranslation,
                             llvm::IRBuilderBase::InsertPoint allocaIP,
                             llvm::IRBuilderBase::InsertPoint codeGenIP) {
  builder.restoreIP(allocaIP);

  // Default to ByRef when no MapInfoData entry matches `input` or the
  // clause carries no explicit capture type.
  omp::VariableCaptureKind capture = omp::VariableCaptureKind::ByRef;

  // Find the associated MapInfoData entry for the current input
  for (size_t i = 0; i < mapData.MapClause.size(); ++i)
    if (mapData.OriginalValue[i] == input) {
      auto mapOp = cast<omp::MapInfoOp>(mapData.MapClause[i]);
      capture =
          mapOp.getMapCaptureType().value_or(omp::VariableCaptureKind::ByRef);

      break;
    }

  unsigned int allocaAS = ompBuilder.M.getDataLayout().getAllocaAddrSpace();
  unsigned int defaultAS =
      ompBuilder.M.getDataLayout().getProgramAddressSpace();

  // Create the alloca for the argument at the current point.
  llvm::Value *v = builder.CreateAlloca(arg.getType(), allocaAS);

  // On targets whose alloca address space differs from the program address
  // space (e.g. some GPU targets), cast pointer-typed slots so later uses
  // see the default address space.
  if (allocaAS != defaultAS && arg.getType()->isPointerTy())
    v = builder.CreateAddrSpaceCast(v, builder.getPtrTy(defaultAS));

  builder.CreateStore(&arg, v);

  // Argument accesses belong in the kernel body, not the alloca block.
  builder.restoreIP(codeGenIP);

  switch (capture) {
  case omp::VariableCaptureKind::ByCopy: {
    retVal = v;
    break;
  }
  case omp::VariableCaptureKind::ByRef: {
    retVal = builder.CreateAlignedLoad(
        v->getType(), v,
        ompBuilder.M.getDataLayout().getPrefTypeAlign(v->getType()));
    break;
  }
  case omp::VariableCaptureKind::This:
  case omp::VariableCaptureKind::VLAType:
    // TODO: Consider returning error to use standard reporting for
    // unimplemented features.
    assert(false && "Currently unsupported capture kind");
    break;
  }

  return builder.saveIP();
}
3944 
3945 /// Follow uses of `host_eval`-defined block arguments of the given `omp.target`
3946 /// operation and populate output variables with their corresponding host value
3947 /// (i.e. operand evaluated outside of the target region), based on their uses
3948 /// inside of the target region.
3949 ///
3950 /// Loop bounds and steps are only optionally populated, if output vectors are
3951 /// provided.
3952 static void extractHostEvalClauses(omp::TargetOp targetOp, Value &numThreads,
3953                                    Value &numTeamsLower, Value &numTeamsUpper,
3954                                    Value &threadLimit) {
3955   auto blockArgIface = llvm::cast<omp::BlockArgOpenMPOpInterface>(*targetOp);
3956   for (auto item : llvm::zip_equal(targetOp.getHostEvalVars(),
3957                                    blockArgIface.getHostEvalBlockArgs())) {
3958     Value hostEvalVar = std::get<0>(item), blockArg = std::get<1>(item);
3959 
3960     for (Operation *user : blockArg.getUsers()) {
3961       llvm::TypeSwitch<Operation *>(user)
3962           .Case([&](omp::TeamsOp teamsOp) {
3963             if (teamsOp.getNumTeamsLower() == blockArg)
3964               numTeamsLower = hostEvalVar;
3965             else if (teamsOp.getNumTeamsUpper() == blockArg)
3966               numTeamsUpper = hostEvalVar;
3967             else if (teamsOp.getThreadLimit() == blockArg)
3968               threadLimit = hostEvalVar;
3969             else
3970               llvm_unreachable("unsupported host_eval use");
3971           })
3972           .Case([&](omp::ParallelOp parallelOp) {
3973             if (parallelOp.getNumThreads() == blockArg)
3974               numThreads = hostEvalVar;
3975             else
3976               llvm_unreachable("unsupported host_eval use");
3977           })
3978           .Case([&](omp::LoopNestOp loopOp) {
3979             // TODO: Extract bounds and step values. Currently, this cannot be
3980             // reached because translation would have been stopped earlier as a
3981             // result of `checkImplementationStatus` detecting and reporting
3982             // this situation.
3983             llvm_unreachable("unsupported host_eval use");
3984           })
3985           .Default([](Operation *) {
3986             llvm_unreachable("unsupported host_eval use");
3987           });
3988     }
3989   }
3990 }
3991 
3992 /// If \p op is of the given type parameter, return it casted to that type.
3993 /// Otherwise, if its immediate parent operation (or some other higher-level
3994 /// parent, if \p immediateParent is false) is of that type, return that parent
3995 /// casted to the given type.
3996 ///
3997 /// If \p op is \c null or neither it or its parent(s) are of the specified
3998 /// type, return a \c null operation.
3999 template <typename OpTy>
4000 static OpTy castOrGetParentOfType(Operation *op, bool immediateParent = false) {
4001   if (!op)
4002     return OpTy();
4003 
4004   if (OpTy casted = dyn_cast<OpTy>(op))
4005     return casted;
4006 
4007   if (immediateParent)
4008     return dyn_cast_if_present<OpTy>(op->getParentOp());
4009 
4010   return op->getParentOfType<OpTy>();
4011 }
4012 
4013 /// If the given \p value is defined by an \c llvm.mlir.constant operation and
4014 /// it is of an integer type, return its value.
4015 static std::optional<int64_t> extractConstInteger(Value value) {
4016   if (!value)
4017     return std::nullopt;
4018 
4019   if (auto constOp =
4020           dyn_cast_if_present<LLVM::ConstantOp>(value.getDefiningOp()))
4021     if (auto constAttr = dyn_cast<IntegerAttr>(constOp.getValue()))
4022       return constAttr.getInt();
4023 
4024   return std::nullopt;
4025 }
4026 
/// Populate default `MinTeams`, `MaxTeams` and `MaxThreads` to their default
/// values as stated by the corresponding clauses, if constant.
///
/// These default values must be set before the creation of the outlined LLVM
/// function for the target region, so that they can be used to initialize the
/// corresponding global `ConfigurationEnvironmentTy` structure.
///
/// Sentinel convention used throughout: a value of -1 means "unset" and 0
/// means "set, but not known at compile time".
static void
initTargetDefaultAttrs(omp::TargetOp targetOp,
                       llvm::OpenMPIRBuilder::TargetKernelDefaultAttrs &attrs,
                       bool isTargetDevice) {
  // TODO: Handle constant 'if' clauses.
  Operation *capturedOp = targetOp.getInnermostCapturedOmpOp();

  // Locate the clause operands that determine team/thread counts. On the
  // host they come from `host_eval`; on the device they are read directly
  // from the captured `teams`/`parallel` constructs.
  Value numThreads, numTeamsLower, numTeamsUpper, threadLimit;
  if (!isTargetDevice) {
    extractHostEvalClauses(targetOp, numThreads, numTeamsLower, numTeamsUpper,
                           threadLimit);
  } else {
    // In the target device, values for these clauses are not passed as
    // host_eval, but instead evaluated prior to entry to the region. This
    // ensures values are mapped and available inside of the target region.
    if (auto teamsOp = castOrGetParentOfType<omp::TeamsOp>(capturedOp)) {
      numTeamsLower = teamsOp.getNumTeamsLower();
      numTeamsUpper = teamsOp.getNumTeamsUpper();
      threadLimit = teamsOp.getThreadLimit();
    }

    if (auto parallelOp = castOrGetParentOfType<omp::ParallelOp>(capturedOp))
      numThreads = parallelOp.getNumThreads();
  }

  // Handle clauses impacting the number of teams.

  int32_t minTeamsVal = 1, maxTeamsVal = -1;
  if (castOrGetParentOfType<omp::TeamsOp>(capturedOp)) {
    // TODO: Use `hostNumTeamsLower` to initialize `minTeamsVal`. For now, match
    // clang and set min and max to the same value.
    if (numTeamsUpper) {
      // Constant num_teams: pin both bounds; non-constant: 0 is not used here
      // (the value stays as-is if extraction fails, i.e. min=1/max=-1).
      if (auto val = extractConstInteger(numTeamsUpper))
        minTeamsVal = maxTeamsVal = *val;
    } else {
      // `teams` present but without num_teams: set but unknown.
      minTeamsVal = maxTeamsVal = 0;
    }
  } else if (castOrGetParentOfType<omp::ParallelOp>(capturedOp,
                                                    /*immediateParent=*/true) ||
             castOrGetParentOfType<omp::SimdOp>(capturedOp,
                                                /*immediateParent=*/true)) {
    // `parallel`/`simd` captured directly under `target`: a single team.
    minTeamsVal = maxTeamsVal = 1;
  } else {
    minTeamsVal = maxTeamsVal = -1;
  }

  // Handle clauses impacting the number of threads.

  // Sets `result` from a clause operand: to its constant value when known,
  // otherwise to 0 ("set but unknown"). Leaves `result` untouched when the
  // clause is absent.
  auto setMaxValueFromClause = [](Value clauseValue, int32_t &result) {
    if (!clauseValue)
      return;

    if (auto val = extractConstInteger(clauseValue))
      result = *val;

    // Found an applicable clause, so it's not undefined. Mark as unknown
    // because it's not constant.
    if (result < 0)
      result = 0;
  };

  // Extract 'thread_limit' clause from 'target' and 'teams' directives.
  int32_t targetThreadLimitVal = -1, teamsThreadLimitVal = -1;
  setMaxValueFromClause(targetOp.getThreadLimit(), targetThreadLimitVal);
  setMaxValueFromClause(threadLimit, teamsThreadLimitVal);

  // Extract 'max_threads' clause from 'parallel' or set to 1 if it's SIMD.
  int32_t maxThreadsVal = -1;
  if (castOrGetParentOfType<omp::ParallelOp>(capturedOp))
    setMaxValueFromClause(numThreads, maxThreadsVal);
  else if (castOrGetParentOfType<omp::SimdOp>(capturedOp,
                                              /*immediateParent=*/true))
    maxThreadsVal = 1;

  // For max values, < 0 means unset, == 0 means set but unknown. Select the
  // minimum value between 'max_threads' and 'thread_limit' clauses that were
  // set.
  int32_t combinedMaxThreadsVal = targetThreadLimitVal;
  if (combinedMaxThreadsVal < 0 ||
      (teamsThreadLimitVal >= 0 && teamsThreadLimitVal < combinedMaxThreadsVal))
    combinedMaxThreadsVal = teamsThreadLimitVal;

  if (combinedMaxThreadsVal < 0 ||
      (maxThreadsVal >= 0 && maxThreadsVal < combinedMaxThreadsVal))
    combinedMaxThreadsVal = maxThreadsVal;

  // Update kernel bounds structure for the `OpenMPIRBuilder` to use.
  attrs.MinTeams = minTeamsVal;
  attrs.MaxTeams.front() = maxTeamsVal;
  attrs.MinThreads = 1;
  attrs.MaxThreads.front() = combinedMaxThreadsVal;
}
4125 
4126 /// Gather LLVM runtime values for all clauses evaluated in the host that are
4127 /// passed to the kernel invocation.
4128 ///
4129 /// This function must be called only when compiling for the host. Also, it will
4130 /// only provide correct results if it's called after the body of \c targetOp
4131 /// has been fully generated.
4132 static void
4133 initTargetRuntimeAttrs(llvm::IRBuilderBase &builder,
4134                        LLVM::ModuleTranslation &moduleTranslation,
4135                        omp::TargetOp targetOp,
4136                        llvm::OpenMPIRBuilder::TargetKernelRuntimeAttrs &attrs) {
4137   Value numThreads, numTeamsLower, numTeamsUpper, teamsThreadLimit;
4138   extractHostEvalClauses(targetOp, numThreads, numTeamsLower, numTeamsUpper,
4139                          teamsThreadLimit);
4140 
4141   // TODO: Handle constant 'if' clauses.
4142   if (Value targetThreadLimit = targetOp.getThreadLimit())
4143     attrs.TargetThreadLimit.front() =
4144         moduleTranslation.lookupValue(targetThreadLimit);
4145 
4146   if (numTeamsLower)
4147     attrs.MinTeams = moduleTranslation.lookupValue(numTeamsLower);
4148 
4149   if (numTeamsUpper)
4150     attrs.MaxTeams.front() = moduleTranslation.lookupValue(numTeamsUpper);
4151 
4152   if (teamsThreadLimit)
4153     attrs.TeamsThreadLimit.front() =
4154         moduleTranslation.lookupValue(teamsThreadLimit);
4155 
4156   if (numThreads)
4157     attrs.MaxThreads = moduleTranslation.lookupValue(numThreads);
4158 
4159   // TODO: Populate attrs.LoopTripCount if it is target SPMD.
4160 }
4161 
/// Convert an `omp.target` operation to LLVM IR: outline the target region
/// into a kernel (or host fallback) function via the OpenMPIRBuilder,
/// handling map/host_eval/private entry block arguments, privatization and
/// its cleanup, depend clauses, and kernel launch attributes.
static LogicalResult
convertOmpTarget(Operation &opInst, llvm::IRBuilderBase &builder,
                 LLVM::ModuleTranslation &moduleTranslation) {
  auto targetOp = cast<omp::TargetOp>(opInst);
  if (failed(checkImplementationStatus(opInst)))
    return failure();

  llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
  bool isTargetDevice = ompBuilder->Config.isTargetDevice();

  auto parentFn = opInst.getParentOfType<LLVM::LLVMFuncOp>();
  auto argIface = cast<omp::BlockArgOpenMPOpInterface>(opInst);
  auto &targetRegion = targetOp.getRegion();
  // Holds the private vars that have been mapped along with the block argument
  // that corresponds to the MapInfoOp corresponding to the private var in
  // question. So, for instance:
  //
  // %10 = omp.map.info var_ptr(%6#0 : !fir.ref<!fir.box<!fir.heap<i32>>>, ..)
  // omp.target map_entries(%10 -> %arg0) private(@box.privatizer %6#0-> %arg1)
  //
  // Then, %10 has been created so that the descriptor can be used by the
  // privatizer @box.privatizer on the device side. Here we'd record {%6#0,
  // %arg0} in the mappedPrivateVars map.
  llvm::DenseMap<Value, Value> mappedPrivateVars;
  DataLayout dl = DataLayout(opInst.getParentOfType<ModuleOp>());
  SmallVector<Value> mapVars = targetOp.getMapVars();
  ArrayRef<BlockArgument> mapBlockArgs = argIface.getMapBlockArgs();
  llvm::Function *llvmOutlinedFn = nullptr;

  // TODO: It can also be false if a compile-time constant `false` IF clause is
  // specified.
  bool isOffloadEntry =
      isTargetDevice || !ompBuilder->Config.TargetTriples.empty();

  // For some private variables, the MapsForPrivatizedVariablesPass
  // creates MapInfoOp instances. Go through the private variables and
  // the mapped variables so that during codegeneration we are able
  // to quickly look up the corresponding map variable, if any for each
  // private variable.
  if (!targetOp.getPrivateVars().empty() && !targetOp.getMapVars().empty()) {
    OperandRange privateVars = targetOp.getPrivateVars();
    std::optional<ArrayAttr> privateSyms = targetOp.getPrivateSyms();
    std::optional<DenseI64ArrayAttr> privateMapIndices =
        targetOp.getPrivateMapsAttr();

    for (auto [privVarIdx, privVarSymPair] :
         llvm::enumerate(llvm::zip_equal(privateVars, *privateSyms))) {
      auto privVar = std::get<0>(privVarSymPair);
      auto privSym = std::get<1>(privVarSymPair);

      SymbolRefAttr privatizerName = llvm::cast<SymbolRefAttr>(privSym);
      omp::PrivateClauseOp privatizer =
          findPrivatizer(targetOp, privatizerName);

      // Privatizers that don't need a map var are handled without an entry in
      // `mappedPrivateVars`.
      if (!privatizer.needsMap())
        continue;

      mlir::Value mappedValue =
          targetOp.getMappedValueForPrivateVar(privVarIdx);
      assert(mappedValue && "Expected to find mapped value for a privatized "
                            "variable that needs mapping");

      // The MapInfoOp defining the map var isn't really needed later.
      // So, we don't store it in any datastructure. Instead, we just
      // do some sanity checks on it right now.
      auto mapInfoOp = mappedValue.getDefiningOp<omp::MapInfoOp>();
      [[maybe_unused]] Type varType = mapInfoOp.getVarType();

      // Check #1: Check that the type of the private variable matches
      // the type of the variable being mapped.
      if (!isa<LLVM::LLVMPointerType>(privVar.getType()))
        assert(
            varType == privVar.getType() &&
            "Type of private var doesn't match the type of the mapped value");

      // Ok, only 1 sanity check for now.
      // Record the block argument corresponding to this mapvar.
      mappedPrivateVars.insert(
          {privVar,
           targetRegion.getArgument(argIface.getMapBlockArgsStart() +
                                    (*privateMapIndices)[privVarIdx])});
    }
  }

  using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
  // Callback generating the body of the outlined function: forwards target
  // attributes, maps entry block arguments to LLVM values, performs
  // privatization, inlines the target region and then the privatizers'
  // `dealloc` regions.
  auto bodyCB = [&](InsertPointTy allocaIP, InsertPointTy codeGenIP)
      -> llvm::OpenMPIRBuilder::InsertPointOrErrorTy {
    // Forward target-cpu and target-features function attributes from the
    // original function to the new outlined function.
    llvm::Function *llvmParentFn =
        moduleTranslation.lookupFunction(parentFn.getName());
    llvmOutlinedFn = codeGenIP.getBlock()->getParent();
    assert(llvmParentFn && llvmOutlinedFn &&
           "Both parent and outlined functions must exist at this point");

    if (auto attr = llvmParentFn->getFnAttribute("target-cpu");
        attr.isStringAttribute())
      llvmOutlinedFn->addFnAttr(attr);

    if (auto attr = llvmParentFn->getFnAttribute("target-features");
        attr.isStringAttribute())
      llvmOutlinedFn->addFnAttr(attr);

    // Map each `map` entry block argument to the LLVM value of its operand.
    for (auto [arg, mapOp] : llvm::zip_equal(mapBlockArgs, mapVars)) {
      auto mapInfoOp = cast<omp::MapInfoOp>(mapOp.getDefiningOp());
      llvm::Value *mapOpValue =
          moduleTranslation.lookupValue(mapInfoOp.getVarPtr());
      moduleTranslation.mapValue(arg, mapOpValue);
    }

    // Do privatization after moduleTranslation has already recorded
    // mapped values.
    MutableArrayRef<BlockArgument> privateBlockArgs =
        argIface.getPrivateBlockArgs();
    SmallVector<mlir::Value> mlirPrivateVars;
    SmallVector<llvm::Value *> llvmPrivateVars;
    SmallVector<omp::PrivateClauseOp> privateDecls;
    mlirPrivateVars.reserve(privateBlockArgs.size());
    llvmPrivateVars.reserve(privateBlockArgs.size());
    collectPrivatizationDecls(targetOp, privateDecls);
    for (mlir::Value privateVar : targetOp.getPrivateVars())
      mlirPrivateVars.push_back(privateVar);

    llvm::Expected<llvm::BasicBlock *> afterAllocas = allocatePrivateVars(
        builder, moduleTranslation, privateBlockArgs, privateDecls,
        mlirPrivateVars, llvmPrivateVars, allocaIP, &mappedPrivateVars);

    if (failed(handleError(afterAllocas, *targetOp)))
      return llvm::make_error<PreviouslyReportedError>();

    // Gather the `dealloc` regions of all privatizers so they can be inlined
    // after the region body below.
    SmallVector<Region *> privateCleanupRegions;
    llvm::transform(privateDecls, std::back_inserter(privateCleanupRegions),
                    [](omp::PrivateClauseOp privatizer) {
                      return &privatizer.getDeallocRegion();
                    });

    builder.restoreIP(codeGenIP);
    llvm::Expected<llvm::BasicBlock *> exitBlock = convertOmpOpRegions(
        targetRegion, "omp.target", builder, moduleTranslation);

    if (!exitBlock)
      return exitBlock.takeError();

    builder.SetInsertPoint(*exitBlock);
    if (!privateCleanupRegions.empty()) {
      if (failed(inlineOmpRegionCleanup(
              privateCleanupRegions, llvmPrivateVars, moduleTranslation,
              builder, "omp.targetop.private.cleanup",
              /*shouldLoadCleanupRegionArg=*/false))) {
        return llvm::createStringError(
            "failed to inline `dealloc` region of `omp.private` "
            "op in the target region");
      }
    }

    return InsertPointTy(exitBlock.get(), exitBlock.get()->end());
  };

  StringRef parentName = parentFn.getName();

  llvm::TargetRegionEntryInfo entryInfo;

  if (!getTargetEntryUniqueInfo(entryInfo, targetOp, parentName))
    return failure();

  MapInfoData mapData;
  collectMapDataFromMapOperands(mapData, mapVars, moduleTranslation, dl,
                                builder);

  // Callback providing the combined map information for the kernel launch.
  llvm::OpenMPIRBuilder::MapInfosTy combinedInfos;
  auto genMapInfoCB = [&](llvm::OpenMPIRBuilder::InsertPointTy codeGenIP)
      -> llvm::OpenMPIRBuilder::MapInfosTy & {
    builder.restoreIP(codeGenIP);
    genMapInfos(builder, moduleTranslation, dl, combinedInfos, mapData, true);
    return combinedInfos;
  };

  auto argAccessorCB = [&](llvm::Argument &arg, llvm::Value *input,
                           llvm::Value *&retVal, InsertPointTy allocaIP,
                           InsertPointTy codeGenIP)
      -> llvm::OpenMPIRBuilder::InsertPointOrErrorTy {
    // We just return the unaltered argument for the host function
    // for now, some alterations may be required in the future to
    // keep host fallback functions working identically to the device
    // version (e.g. pass ByCopy values should be treated as such on
    // host and device, currently not always the case)
    if (!isTargetDevice) {
      retVal = cast<llvm::Value>(&arg);
      return codeGenIP;
    }

    return createDeviceArgumentAccessor(mapData, arg, input, retVal, builder,
                                        *ompBuilder, moduleTranslation,
                                        allocaIP, codeGenIP);
  };

  // Compute compile-time default kernel launch bounds.
  llvm::OpenMPIRBuilder::TargetKernelRuntimeAttrs runtimeAttrs;
  llvm::OpenMPIRBuilder::TargetKernelDefaultAttrs defaultAttrs;
  initTargetDefaultAttrs(targetOp, defaultAttrs, isTargetDevice);

  // Collect host-evaluated values needed to properly launch the kernel from the
  // host.
  if (!isTargetDevice)
    initTargetRuntimeAttrs(builder, moduleTranslation, targetOp, runtimeAttrs);

  // Pass host-evaluated values as parameters to the kernel / host fallback,
  // except if they are constants. In any case, map the MLIR block argument to
  // the corresponding LLVM values.
  llvm::SmallVector<llvm::Value *, 4> kernelInput;
  SmallVector<Value> hostEvalVars = targetOp.getHostEvalVars();
  ArrayRef<BlockArgument> hostEvalBlockArgs = argIface.getHostEvalBlockArgs();
  for (auto [arg, var] : llvm::zip_equal(hostEvalBlockArgs, hostEvalVars)) {
    llvm::Value *value = moduleTranslation.lookupValue(var);
    moduleTranslation.mapValue(arg, value);

    if (!llvm::isa<llvm::Constant>(value))
      kernelInput.push_back(value);
  }

  for (size_t i = 0; i < mapVars.size(); ++i) {
    // declare target arguments are not passed to kernels as arguments
    // TODO: We currently do not handle cases where a member is explicitly
    // passed in as an argument, this will likley need to be handled in
    // the near future, rather than using IsAMember, it may be better to
    // test if the relevant BlockArg is used within the target region and
    // then use that as a basis for exclusion in the kernel inputs.
    if (!mapData.IsDeclareTarget[i] && !mapData.IsAMember[i])
      kernelInput.push_back(mapData.OriginalValue[i]);
  }

  // Collect `depend` clause dependencies, if any.
  SmallVector<llvm::OpenMPIRBuilder::DependData> dds;
  buildDependData(targetOp.getDependKinds(), targetOp.getDependVars(),
                  moduleTranslation, dds);

  llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
      findAllocaInsertPoint(builder, moduleTranslation);
  llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);

  llvm::Value *ifCond = nullptr;
  if (Value targetIfCond = targetOp.getIfExpr())
    ifCond = moduleTranslation.lookupValue(targetIfCond);

  llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
      moduleTranslation.getOpenMPBuilder()->createTarget(
          ompLoc, isOffloadEntry, allocaIP, builder.saveIP(), entryInfo,
          defaultAttrs, runtimeAttrs, ifCond, kernelInput, genMapInfoCB, bodyCB,
          argAccessorCB, dds, targetOp.getNowait());

  if (failed(handleError(afterIP, opInst)))
    return failure();

  builder.restoreIP(*afterIP);

  // Remap access operations to declare target reference pointers for the
  // device, essentially generating extra loadop's as necessary
  if (moduleTranslation.getOpenMPBuilder()->Config.isTargetDevice())
    handleDeclareTargetMapVar(mapData, moduleTranslation, builder,
                              llvmOutlinedFn);

  return success();
}
4423 
/// Handle the `omp.declare_target` attribute attached to \p op: erase
/// host-only wrapper functions when compiling for the device, and register
/// `declare target` global variables with the `OpenMPIRBuilder`.
static LogicalResult
convertDeclareTargetAttr(Operation *op, mlir::omp::DeclareTargetAttr attribute,
                         LLVM::ModuleTranslation &moduleTranslation) {
  // Amend omp.declare_target by deleting the IR of the outlined functions
  // created for target regions. They cannot be filtered out from MLIR earlier
  // because the omp.target operation inside must be translated to LLVM, but
  // the wrapper functions themselves must not remain at the end of the
  // process. We know that functions where omp.declare_target does not match
  // omp.is_target_device at this stage can only be wrapper functions because
  // those that aren't are removed earlier as an MLIR transformation pass.
  if (FunctionOpInterface funcOp = dyn_cast<FunctionOpInterface>(op)) {
    if (auto offloadMod = dyn_cast<omp::OffloadModuleInterface>(
            op->getParentOfType<ModuleOp>().getOperation())) {
      // Nothing to delete on the host side.
      if (!offloadMod.getIsTargetDevice())
        return success();

      omp::DeclareTargetDeviceType declareType =
          attribute.getDeviceType().getValue();

      // Device compilation: host-only functions are wrappers to be erased.
      if (declareType == omp::DeclareTargetDeviceType::host) {
        llvm::Function *llvmFunc =
            moduleTranslation.lookupFunction(funcOp.getName());
        llvmFunc->dropAllReferences();
        llvmFunc->eraseFromParent();
      }
    }
    return success();
  }

  // Register `declare target` globals with the OpenMPIRBuilder, using the
  // source location (when available) to build the unique offload entry name.
  if (LLVM::GlobalOp gOp = dyn_cast<LLVM::GlobalOp>(op)) {
    llvm::Module *llvmModule = moduleTranslation.getLLVMModule();
    if (auto *gVal = llvmModule->getNamedValue(gOp.getSymName())) {
      llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
      bool isDeclaration = gOp.isDeclaration();
      bool isExternallyVisible =
          gOp.getVisibility() != mlir::SymbolTable::Visibility::Private;
      auto loc = op->getLoc()->findInstanceOf<FileLineColLoc>();
      llvm::StringRef mangledName = gOp.getSymName();
      auto captureClause =
          convertToCaptureClauseKind(attribute.getCaptureClause().getValue());
      auto deviceClause =
          convertToDeviceClauseKind(attribute.getDeviceType().getValue());
      // unused for MLIR at the moment, required in Clang for book
      // keeping
      std::vector<llvm::GlobalVariable *> generatedRefs;

      // Target triple of the offload target, taken from the module attribute
      // when present.
      std::vector<llvm::Triple> targetTriple;
      auto targetTripleAttr = dyn_cast_or_null<mlir::StringAttr>(
          op->getParentOfType<mlir::ModuleOp>()->getAttr(
              LLVM::LLVMDialect::getTargetTripleAttrName()));
      if (targetTripleAttr)
        targetTriple.emplace_back(targetTripleAttr.data());

      // Supplies (filename, line) for unique entry info; falls back to an
      // empty name / line 0 when the location is not a FileLineColLoc.
      auto fileInfoCallBack = [&loc]() {
        std::string filename = "";
        std::uint64_t lineNo = 0;

        if (loc) {
          filename = loc.getFilename().str();
          lineNo = loc.getLine();
        }

        return std::pair<std::string, std::uint64_t>(llvm::StringRef(filename),
                                                     lineNo);
      };

      ompBuilder->registerTargetGlobalVariable(
          captureClause, deviceClause, isDeclaration, isExternallyVisible,
          ompBuilder->getTargetEntryUniqueInfo(fileInfoCallBack), mangledName,
          generatedRefs, /*OpenMPSimd*/ false, targetTriple,
          /*GlobalInitializer*/ nullptr, /*VariableLinkage*/ nullptr,
          gVal->getType(), gVal);

      // On the device, `link`/`enter` captures (or `to` under unified shared
      // memory) additionally need the declare target var address materialized.
      if (ompBuilder->Config.isTargetDevice() &&
          (attribute.getCaptureClause().getValue() !=
               mlir::omp::DeclareTargetCaptureClause::to ||
           ompBuilder->Config.hasRequiresUnifiedSharedMemory())) {
        ompBuilder->getAddrOfDeclareTargetVar(
            captureClause, deviceClause, isDeclaration, isExternallyVisible,
            ompBuilder->getTargetEntryUniqueInfo(fileInfoCallBack), mangledName,
            generatedRefs, /*OpenMPSimd*/ false, targetTriple, gVal->getType(),
            /*GlobalInitializer*/ nullptr,
            /*VariableLinkage*/ nullptr);
      }
    }
  }

  return success();
}
4513 
4514 // Returns true if the operation is inside a TargetOp or
4515 // is part of a declare target function.
4516 static bool isTargetDeviceOp(Operation *op) {
4517   // Assumes no reverse offloading
4518   if (op->getParentOfType<omp::TargetOp>())
4519     return true;
4520 
4521   // Certain operations return results, and whether utilised in host or
4522   // target there is a chance an LLVM Dialect operation depends on it
4523   // by taking it in as an operand, so we must always lower these in
4524   // some manner or result in an ICE (whether they end up in a no-op
4525   // or otherwise).
4526   if (mlir::isa<omp::ThreadprivateOp>(op))
4527     return true;
4528 
4529   if (auto parentFn = op->getParentOfType<LLVM::LLVMFuncOp>())
4530     if (auto declareTargetIface =
4531             llvm::dyn_cast<mlir::omp::DeclareTargetInterface>(
4532                 parentFn.getOperation()))
4533       if (declareTargetIface.isDeclareTarget() &&
4534           declareTargetIface.getDeclareTargetDeviceType() !=
4535               mlir::omp::DeclareTargetDeviceType::host)
4536         return true;
4537 
4538   return false;
4539 }
4540 
/// Given an OpenMP MLIR operation, create the corresponding LLVM IR
/// (including OpenMP runtime calls).
static LogicalResult
convertHostOrTargetOperation(Operation *op, llvm::IRBuilderBase &builder,
                             LLVM::ModuleTranslation &moduleTranslation) {

  llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();

  // Dispatch on the concrete OpenMP operation type. Operations handled as
  // part of another operation's translation (e.g. `omp.map.info`) are
  // deliberate no-ops here; unhandled operations produce an error.
  return llvm::TypeSwitch<Operation *, LogicalResult>(op)
      .Case([&](omp::BarrierOp op) -> LogicalResult {
        if (failed(checkImplementationStatus(*op)))
          return failure();

        llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
            ompBuilder->createBarrier(builder.saveIP(),
                                      llvm::omp::OMPD_barrier);
        return handleError(afterIP, *op);
      })
      .Case([&](omp::TaskyieldOp op) {
        if (failed(checkImplementationStatus(*op)))
          return failure();

        ompBuilder->createTaskyield(builder.saveIP());
        return success();
      })
      .Case([&](omp::FlushOp op) {
        if (failed(checkImplementationStatus(*op)))
          return failure();

        // No support in Openmp runtime function (__kmpc_flush) to accept
        // the argument list.
        // OpenMP standard states the following:
        //  "An implementation may implement a flush with a list by ignoring
        //   the list, and treating it the same as a flush without a list."
        //
        // The argument list is discarded so that, flush with a list is treated
        // same as a flush without a list.
        ompBuilder->createFlush(builder.saveIP());
        return success();
      })
      .Case([&](omp::ParallelOp op) {
        return convertOmpParallel(op, builder, moduleTranslation);
      })
      .Case([&](omp::MaskedOp) {
        return convertOmpMasked(*op, builder, moduleTranslation);
      })
      .Case([&](omp::MasterOp) {
        return convertOmpMaster(*op, builder, moduleTranslation);
      })
      .Case([&](omp::CriticalOp) {
        return convertOmpCritical(*op, builder, moduleTranslation);
      })
      .Case([&](omp::OrderedRegionOp) {
        return convertOmpOrderedRegion(*op, builder, moduleTranslation);
      })
      .Case([&](omp::OrderedOp) {
        return convertOmpOrdered(*op, builder, moduleTranslation);
      })
      .Case([&](omp::WsloopOp) {
        return convertOmpWsloop(*op, builder, moduleTranslation);
      })
      .Case([&](omp::SimdOp) {
        return convertOmpSimd(*op, builder, moduleTranslation);
      })
      .Case([&](omp::AtomicReadOp) {
        return convertOmpAtomicRead(*op, builder, moduleTranslation);
      })
      .Case([&](omp::AtomicWriteOp) {
        return convertOmpAtomicWrite(*op, builder, moduleTranslation);
      })
      .Case([&](omp::AtomicUpdateOp op) {
        return convertOmpAtomicUpdate(op, builder, moduleTranslation);
      })
      .Case([&](omp::AtomicCaptureOp op) {
        return convertOmpAtomicCapture(op, builder, moduleTranslation);
      })
      .Case([&](omp::SectionsOp) {
        return convertOmpSections(*op, builder, moduleTranslation);
      })
      .Case([&](omp::SingleOp op) {
        return convertOmpSingle(op, builder, moduleTranslation);
      })
      .Case([&](omp::TeamsOp op) {
        return convertOmpTeams(op, builder, moduleTranslation);
      })
      .Case([&](omp::TaskOp op) {
        return convertOmpTaskOp(op, builder, moduleTranslation);
      })
      .Case([&](omp::TaskgroupOp op) {
        return convertOmpTaskgroupOp(op, builder, moduleTranslation);
      })
      .Case([&](omp::TaskwaitOp op) {
        return convertOmpTaskwaitOp(op, builder, moduleTranslation);
      })
      .Case<omp::YieldOp, omp::TerminatorOp, omp::DeclareReductionOp,
            omp::CriticalDeclareOp>([](auto op) {
        // `yield` and `terminator` can be just omitted. The block structure
        // was created in the region that handles their parent operation.
        // `declare_reduction` will be used by reductions and is not
        // converted directly, skip it.
        // `critical.declare` is only used to declare names of critical
        // sections which will be used by `critical` ops and hence can be
        // ignored for lowering. The OpenMP IRBuilder will create unique
        // name for critical section names.
        return success();
      })
      .Case([&](omp::ThreadprivateOp) {
        return convertOmpThreadprivate(*op, builder, moduleTranslation);
      })
      .Case<omp::TargetDataOp, omp::TargetEnterDataOp, omp::TargetExitDataOp,
            omp::TargetUpdateOp>([&](auto op) {
        return convertOmpTargetData(op, builder, moduleTranslation);
      })
      .Case([&](omp::TargetOp) {
        return convertOmpTarget(*op, builder, moduleTranslation);
      })
      .Case<omp::MapInfoOp, omp::MapBoundsOp, omp::PrivateClauseOp>(
          [&](auto op) {
            // No-op, should be handled by relevant owning operations e.g.
            // TargetOp, TargetEnterDataOp, TargetExitDataOp, TargetDataOp etc.
            // and then discarded
            return success();
          })
      .Default([&](Operation *inst) {
        return inst->emitError() << "not yet implemented: " << inst->getName();
      });
}
4668 
/// Translate an operation found in device (target) code by delegating to the
/// common host/target lowering logic.
static LogicalResult
convertTargetDeviceOp(Operation *op, llvm::IRBuilderBase &builder,
                      LLVM::ModuleTranslation &moduleTranslation) {
  return convertHostOrTargetOperation(op, builder, moduleTranslation);
}
4674 
4675 static LogicalResult
4676 convertTargetOpsInNest(Operation *op, llvm::IRBuilderBase &builder,
4677                        LLVM::ModuleTranslation &moduleTranslation) {
4678   if (isa<omp::TargetOp>(op))
4679     return convertOmpTarget(*op, builder, moduleTranslation);
4680   if (isa<omp::TargetDataOp>(op))
4681     return convertOmpTargetData(op, builder, moduleTranslation);
4682   bool interrupted =
4683       op->walk<WalkOrder::PreOrder>([&](Operation *oper) {
4684           if (isa<omp::TargetOp>(oper)) {
4685             if (failed(convertOmpTarget(*oper, builder, moduleTranslation)))
4686               return WalkResult::interrupt();
4687             return WalkResult::skip();
4688           }
4689           if (isa<omp::TargetDataOp>(oper)) {
4690             if (failed(convertOmpTargetData(oper, builder, moduleTranslation)))
4691               return WalkResult::interrupt();
4692             return WalkResult::skip();
4693           }
4694           return WalkResult::advance();
4695         }).wasInterrupted();
4696   return failure(interrupted);
4697 }
4698 
namespace {

/// Implementation of the dialect interface that converts operations belonging
/// to the OpenMP dialect to LLVM IR. Registered on the OpenMP dialect by
/// registerOpenMPDialectTranslation below; both overrides are defined
/// out-of-line after this namespace.
class OpenMPDialectLLVMIRTranslationInterface
    : public LLVMTranslationDialectInterface {
public:
  using LLVMTranslationDialectInterface::LLVMTranslationDialectInterface;

  /// Translates the given operation to LLVM IR using the provided IR builder
  /// and saving the state in `moduleTranslation`.
  LogicalResult
  convertOperation(Operation *op, llvm::IRBuilderBase &builder,
                   LLVM::ModuleTranslation &moduleTranslation) const final;

  /// Given an OpenMP MLIR attribute, create the corresponding LLVM-IR,
  /// runtime calls, or operation amendments
  LogicalResult
  amendOperation(Operation *op, ArrayRef<llvm::Instruction *> instructions,
                 NamedAttribute attribute,
                 LLVM::ModuleTranslation &moduleTranslation) const final;
};

} // namespace
4723 
4724 LogicalResult OpenMPDialectLLVMIRTranslationInterface::amendOperation(
4725     Operation *op, ArrayRef<llvm::Instruction *> instructions,
4726     NamedAttribute attribute,
4727     LLVM::ModuleTranslation &moduleTranslation) const {
4728   return llvm::StringSwitch<llvm::function_ref<LogicalResult(Attribute)>>(
4729              attribute.getName())
4730       .Case("omp.is_target_device",
4731             [&](Attribute attr) {
4732               if (auto deviceAttr = dyn_cast<BoolAttr>(attr)) {
4733                 llvm::OpenMPIRBuilderConfig &config =
4734                     moduleTranslation.getOpenMPBuilder()->Config;
4735                 config.setIsTargetDevice(deviceAttr.getValue());
4736                 return success();
4737               }
4738               return failure();
4739             })
4740       .Case("omp.is_gpu",
4741             [&](Attribute attr) {
4742               if (auto gpuAttr = dyn_cast<BoolAttr>(attr)) {
4743                 llvm::OpenMPIRBuilderConfig &config =
4744                     moduleTranslation.getOpenMPBuilder()->Config;
4745                 config.setIsGPU(gpuAttr.getValue());
4746                 return success();
4747               }
4748               return failure();
4749             })
4750       .Case("omp.host_ir_filepath",
4751             [&](Attribute attr) {
4752               if (auto filepathAttr = dyn_cast<StringAttr>(attr)) {
4753                 llvm::OpenMPIRBuilder *ompBuilder =
4754                     moduleTranslation.getOpenMPBuilder();
4755                 ompBuilder->loadOffloadInfoMetadata(filepathAttr.getValue());
4756                 return success();
4757               }
4758               return failure();
4759             })
4760       .Case("omp.flags",
4761             [&](Attribute attr) {
4762               if (auto rtlAttr = dyn_cast<omp::FlagsAttr>(attr))
4763                 return convertFlagsAttr(op, rtlAttr, moduleTranslation);
4764               return failure();
4765             })
4766       .Case("omp.version",
4767             [&](Attribute attr) {
4768               if (auto versionAttr = dyn_cast<omp::VersionAttr>(attr)) {
4769                 llvm::OpenMPIRBuilder *ompBuilder =
4770                     moduleTranslation.getOpenMPBuilder();
4771                 ompBuilder->M.addModuleFlag(llvm::Module::Max, "openmp",
4772                                             versionAttr.getVersion());
4773                 return success();
4774               }
4775               return failure();
4776             })
4777       .Case("omp.declare_target",
4778             [&](Attribute attr) {
4779               if (auto declareTargetAttr =
4780                       dyn_cast<omp::DeclareTargetAttr>(attr))
4781                 return convertDeclareTargetAttr(op, declareTargetAttr,
4782                                                 moduleTranslation);
4783               return failure();
4784             })
4785       .Case("omp.requires",
4786             [&](Attribute attr) {
4787               if (auto requiresAttr = dyn_cast<omp::ClauseRequiresAttr>(attr)) {
4788                 using Requires = omp::ClauseRequires;
4789                 Requires flags = requiresAttr.getValue();
4790                 llvm::OpenMPIRBuilderConfig &config =
4791                     moduleTranslation.getOpenMPBuilder()->Config;
4792                 config.setHasRequiresReverseOffload(
4793                     bitEnumContainsAll(flags, Requires::reverse_offload));
4794                 config.setHasRequiresUnifiedAddress(
4795                     bitEnumContainsAll(flags, Requires::unified_address));
4796                 config.setHasRequiresUnifiedSharedMemory(
4797                     bitEnumContainsAll(flags, Requires::unified_shared_memory));
4798                 config.setHasRequiresDynamicAllocators(
4799                     bitEnumContainsAll(flags, Requires::dynamic_allocators));
4800                 return success();
4801               }
4802               return failure();
4803             })
4804       .Case("omp.target_triples",
4805             [&](Attribute attr) {
4806               if (auto triplesAttr = dyn_cast<ArrayAttr>(attr)) {
4807                 llvm::OpenMPIRBuilderConfig &config =
4808                     moduleTranslation.getOpenMPBuilder()->Config;
4809                 config.TargetTriples.clear();
4810                 config.TargetTriples.reserve(triplesAttr.size());
4811                 for (Attribute tripleAttr : triplesAttr) {
4812                   if (auto tripleStrAttr = dyn_cast<StringAttr>(tripleAttr))
4813                     config.TargetTriples.emplace_back(tripleStrAttr.getValue());
4814                   else
4815                     return failure();
4816                 }
4817                 return success();
4818               }
4819               return failure();
4820             })
4821       .Default([](Attribute) {
4822         // Fall through for omp attributes that do not require lowering.
4823         return success();
4824       })(attribute.getValue());
4825 
4826   return failure();
4827 }
4828 
4829 /// Given an OpenMP MLIR operation, create the corresponding LLVM IR
4830 /// (including OpenMP runtime calls).
4831 LogicalResult OpenMPDialectLLVMIRTranslationInterface::convertOperation(
4832     Operation *op, llvm::IRBuilderBase &builder,
4833     LLVM::ModuleTranslation &moduleTranslation) const {
4834 
4835   llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
4836   if (ompBuilder->Config.isTargetDevice()) {
4837     if (isTargetDeviceOp(op)) {
4838       return convertTargetDeviceOp(op, builder, moduleTranslation);
4839     } else {
4840       return convertTargetOpsInNest(op, builder, moduleTranslation);
4841     }
4842   }
4843   return convertHostOrTargetOperation(op, builder, moduleTranslation);
4844 }
4845 
/// Registers the OpenMP dialect and its LLVM IR translation interface with
/// the given registry. The extension callback attaches the interface lazily,
/// once the dialect is actually loaded into a context.
void mlir::registerOpenMPDialectTranslation(DialectRegistry &registry) {
  registry.insert<omp::OpenMPDialect>();
  // `+` converts the capture-less lambda to the plain function pointer that
  // addExtension expects.
  registry.addExtension(+[](MLIRContext *ctx, omp::OpenMPDialect *dialect) {
    dialect->addInterfaces<OpenMPDialectLLVMIRTranslationInterface>();
  });
}
4852 
/// Convenience overload: registers the OpenMP translation interface directly
/// on an existing context by routing through a temporary DialectRegistry.
void mlir::registerOpenMPDialectTranslation(MLIRContext &context) {
  DialectRegistry registry;
  registerOpenMPDialectTranslation(registry);
  context.appendDialectRegistry(registry);
}
4858