xref: /llvm-project/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp (revision d87964de78ce692fd132ea453c32e4435309a306)
1 //===- llvm/unittest/IR/OpenMPIRBuilderTest.cpp - OpenMPIRBuilder tests ---===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "llvm/Frontend/OpenMP/OMPConstants.h"
10 #include "llvm/Frontend/OpenMP/OMPDeviceConstants.h"
11 #include "llvm/Frontend/OpenMP/OMPIRBuilder.h"
12 #include "llvm/IR/BasicBlock.h"
13 #include "llvm/IR/DIBuilder.h"
14 #include "llvm/IR/Function.h"
15 #include "llvm/IR/InstIterator.h"
16 #include "llvm/IR/Instructions.h"
17 #include "llvm/IR/LLVMContext.h"
18 #include "llvm/IR/Module.h"
19 #include "llvm/IR/Verifier.h"
20 #include "llvm/Passes/PassBuilder.h"
21 #include "llvm/Support/Casting.h"
22 #include "llvm/Transforms/Utils/BasicBlockUtils.h"
23 #include "gmock/gmock.h"
24 #include "gtest/gtest.h"
25 #include <optional>
26 
27 using namespace llvm;
28 using namespace omp;
29 
// Adapters that turn a void-returning callback into the Error-returning form
// the OpenMPIRBuilder callbacks expect. gtest's fatal ASSERT*() macros may only
// be used in functions returning void, so test callbacks are written that way
// and wrapped here. The wrapped callable is captured by reference and the
// adapter always reports success.
#define FINICB_WRAPPER(cb)                                                     \
  [&cb](InsertPointTy FiniIP) -> Error {                                       \
    cb(FiniIP);                                                                \
    return Error::success();                                                   \
  }

#define BODYGENCB_WRAPPER(cb)                                                  \
  [&cb](InsertPointTy BodyAllocaIP, InsertPointTy BodyCodeGenIP) -> Error {    \
    cb(BodyAllocaIP, BodyCodeGenIP);                                           \
    return Error::success();                                                   \
  }
43 
44 namespace {
45 
46 /// Create an instruction that uses the values in \p Values. We use "printf"
47 /// just because it is often used for this purpose in test code, but it is never
48 /// executed here.
49 static CallInst *createPrintfCall(IRBuilder<> &Builder, StringRef FormatStr,
50                                   ArrayRef<Value *> Values) {
51   Module *M = Builder.GetInsertBlock()->getParent()->getParent();
52 
53   GlobalVariable *GV = Builder.CreateGlobalString(FormatStr, "", 0, M);
54   Constant *Zero = ConstantInt::get(Type::getInt32Ty(M->getContext()), 0);
55   Constant *Indices[] = {Zero, Zero};
56   Constant *FormatStrConst =
57       ConstantExpr::getInBoundsGetElementPtr(GV->getValueType(), GV, Indices);
58 
59   Function *PrintfDecl = M->getFunction("printf");
60   if (!PrintfDecl) {
61     GlobalValue::LinkageTypes Linkage = Function::ExternalLinkage;
62     FunctionType *Ty = FunctionType::get(Builder.getInt32Ty(), true);
63     PrintfDecl = Function::Create(Ty, Linkage, "printf", M);
64   }
65 
66   SmallVector<Value *, 4> Args;
67   Args.push_back(FormatStrConst);
68   Args.append(Values.begin(), Values.end());
69   return Builder.CreateCall(PrintfDecl, Args);
70 }
71 
72 /// Verify that blocks in \p RefOrder are corresponds to the depth-first visit
73 /// order the control flow of \p F.
74 ///
75 /// This is an easy way to verify the branching structure of the CFG without
76 /// checking every branch instruction individually. For the CFG of a
77 /// CanonicalLoopInfo, the Cond BB's terminating branch's first edge is entering
78 /// the body, i.e. the DFS order corresponds to the execution order with one
79 /// loop iteration.
80 static testing::AssertionResult
81 verifyDFSOrder(Function *F, ArrayRef<BasicBlock *> RefOrder) {
82   ArrayRef<BasicBlock *>::iterator It = RefOrder.begin();
83   ArrayRef<BasicBlock *>::iterator E = RefOrder.end();
84 
85   df_iterator_default_set<BasicBlock *, 16> Visited;
86   auto DFS = llvm::depth_first_ext(&F->getEntryBlock(), Visited);
87 
88   BasicBlock *Prev = nullptr;
89   for (BasicBlock *BB : DFS) {
90     if (It != E && BB == *It) {
91       Prev = *It;
92       ++It;
93     }
94   }
95 
96   if (It == E)
97     return testing::AssertionSuccess();
98   if (!Prev)
99     return testing::AssertionFailure()
100            << "Did not find " << (*It)->getName() << " in control flow";
101   return testing::AssertionFailure()
102          << "Expected " << Prev->getName() << " before " << (*It)->getName()
103          << " in control flow";
104 }
105 
106 /// Verify that blocks in \p RefOrder are in the same relative order in the
107 /// linked lists of blocks in \p F. The linked list may contain additional
108 /// blocks in-between.
109 ///
110 /// While the order in the linked list is not relevant for semantics, keeping
111 /// the order roughly in execution order makes its printout easier to read.
112 static testing::AssertionResult
113 verifyListOrder(Function *F, ArrayRef<BasicBlock *> RefOrder) {
114   ArrayRef<BasicBlock *>::iterator It = RefOrder.begin();
115   ArrayRef<BasicBlock *>::iterator E = RefOrder.end();
116 
117   BasicBlock *Prev = nullptr;
118   for (BasicBlock &BB : *F) {
119     if (It != E && &BB == *It) {
120       Prev = *It;
121       ++It;
122     }
123   }
124 
125   if (It == E)
126     return testing::AssertionSuccess();
127   if (!Prev)
128     return testing::AssertionFailure() << "Did not find " << (*It)->getName()
129                                        << " in function " << F->getName();
130   return testing::AssertionFailure()
131          << "Expected " << Prev->getName() << " before " << (*It)->getName()
132          << " in function " << F->getName();
133 }
134 
135 /// Populate Calls with call instructions calling the function with the given
136 /// FnID from the given function F.
137 static void findCalls(Function *F, omp::RuntimeFunction FnID,
138                       OpenMPIRBuilder &OMPBuilder,
139                       SmallVectorImpl<CallInst *> &Calls) {
140   Function *Fn = OMPBuilder.getOrCreateRuntimeFunctionPtr(FnID);
141   for (BasicBlock &BB : *F) {
142     for (Instruction &I : BB) {
143       auto *Call = dyn_cast<CallInst>(&I);
144       if (Call && Call->getCalledFunction() == Fn)
145         Calls.push_back(Call);
146     }
147   }
148 }
149 
150 /// Assuming \p F contains only one call to the function with the given \p FnID,
151 /// return that call.
152 static CallInst *findSingleCall(Function *F, omp::RuntimeFunction FnID,
153                                 OpenMPIRBuilder &OMPBuilder) {
154   SmallVector<CallInst *, 1> Calls;
155   findCalls(F, FnID, OMPBuilder, Calls);
156   EXPECT_EQ(1u, Calls.size());
157   if (Calls.size() != 1)
158     return nullptr;
159   return Calls.front();
160 }
161 
162 static omp::ScheduleKind getSchedKind(omp::OMPScheduleType SchedType) {
163   switch (SchedType & ~omp::OMPScheduleType::ModifierMask) {
164   case omp::OMPScheduleType::BaseDynamicChunked:
165     return omp::OMP_SCHEDULE_Dynamic;
166   case omp::OMPScheduleType::BaseGuidedChunked:
167     return omp::OMP_SCHEDULE_Guided;
168   case omp::OMPScheduleType::BaseAuto:
169     return omp::OMP_SCHEDULE_Auto;
170   case omp::OMPScheduleType::BaseRuntime:
171     return omp::OMP_SCHEDULE_Runtime;
172   default:
173     llvm_unreachable("unknown type for this test");
174   }
175 }
176 
177 class OpenMPIRBuilderTest : public testing::Test {
178 protected:
179   void SetUp() override {
180     M.reset(new Module("MyModule", Ctx));
181     FunctionType *FTy =
182         FunctionType::get(Type::getVoidTy(Ctx), {Type::getInt32Ty(Ctx)},
183                           /*isVarArg=*/false);
184     F = Function::Create(FTy, Function::ExternalLinkage, "", M.get());
185     BB = BasicBlock::Create(Ctx, "", F);
186 
187     DIBuilder DIB(*M);
188     auto File = DIB.createFile("test.dbg", "/src", std::nullopt,
189                                std::optional<StringRef>("/src/test.dbg"));
190     auto CU =
191         DIB.createCompileUnit(dwarf::DW_LANG_C, File, "llvm-C", true, "", 0);
192     auto Type = DIB.createSubroutineType(DIB.getOrCreateTypeArray({}));
193     auto SP = DIB.createFunction(
194         CU, "foo", "", File, 1, Type, 1, DINode::FlagZero,
195         DISubprogram::SPFlagDefinition | DISubprogram::SPFlagOptimized);
196     F->setSubprogram(SP);
197     auto Scope = DIB.createLexicalBlockFile(SP, File, 0);
198     DIB.finalize();
199     DL = DILocation::get(Ctx, 3, 7, Scope);
200   }
201 
202   void TearDown() override {
203     BB = nullptr;
204     M.reset();
205   }
206 
207   /// Create a function with a simple loop that calls printf using the logical
208   /// loop counter for use with tests that need a CanonicalLoopInfo object.
209   CanonicalLoopInfo *buildSingleLoopFunction(DebugLoc DL,
210                                              OpenMPIRBuilder &OMPBuilder,
211                                              int UseIVBits,
212                                              CallInst **Call = nullptr,
213                                              BasicBlock **BodyCode = nullptr) {
214     OMPBuilder.initialize();
215     F->setName("func");
216 
217     IRBuilder<> Builder(BB);
218     OpenMPIRBuilder::LocationDescription Loc({Builder.saveIP(), DL});
219     Value *TripCount = F->getArg(0);
220 
221     Type *IVType = Type::getIntNTy(Builder.getContext(), UseIVBits);
222     Value *CastedTripCount =
223         Builder.CreateZExtOrTrunc(TripCount, IVType, "tripcount");
224 
225     auto LoopBodyGenCB = [&](OpenMPIRBuilder::InsertPointTy CodeGenIP,
226                              llvm::Value *LC) {
227       Builder.restoreIP(CodeGenIP);
228       if (BodyCode)
229         *BodyCode = Builder.GetInsertBlock();
230 
231       // Add something that consumes the induction variable to the body.
232       CallInst *CallInst = createPrintfCall(Builder, "%d\\n", {LC});
233       if (Call)
234         *Call = CallInst;
235 
236       return Error::success();
237     };
238     Expected<CanonicalLoopInfo *> LoopResult =
239         OMPBuilder.createCanonicalLoop(Loc, LoopBodyGenCB, CastedTripCount);
240     assert(LoopResult && "unexpected error");
241     CanonicalLoopInfo *Loop = *LoopResult;
242 
243     // Finalize the function.
244     Builder.restoreIP(Loop->getAfterIP());
245     Builder.CreateRetVoid();
246 
247     return Loop;
248   }
249 
250   LLVMContext Ctx;
251   std::unique_ptr<Module> M;
252   Function *F;
253   BasicBlock *BB;
254   DebugLoc DL;
255 };
256 
257 class OpenMPIRBuilderTestWithParams
258     : public OpenMPIRBuilderTest,
259       public ::testing::WithParamInterface<omp::OMPScheduleType> {};
260 
261 class OpenMPIRBuilderTestWithIVBits
262     : public OpenMPIRBuilderTest,
263       public ::testing::WithParamInterface<int> {};
264 
265 // Returns the value stored in the given allocation. Returns null if the given
266 // value is not a result of an InstTy instruction, if no value is stored or if
267 // there is more than one store.
268 template <typename InstTy> static Value *findStoredValue(Value *AllocaValue) {
269   Instruction *Inst = dyn_cast<InstTy>(AllocaValue);
270   if (!Inst)
271     return nullptr;
272   StoreInst *Store = nullptr;
273   for (Use &U : Inst->uses()) {
274     if (auto *CandidateStore = dyn_cast<StoreInst>(U.getUser())) {
275       EXPECT_EQ(Store, nullptr);
276       Store = CandidateStore;
277     }
278   }
279   if (!Store)
280     return nullptr;
281   return Store->getValueOperand();
282 }
283 
284 // Returns the value stored in the aggregate argument of an outlined function,
285 // or nullptr if it is not found.
286 static Value *findStoredValueInAggregateAt(LLVMContext &Ctx, Value *Aggregate,
287                                            unsigned Idx) {
288   GetElementPtrInst *GEPAtIdx = nullptr;
289   // Find GEP instruction at that index.
290   for (User *Usr : Aggregate->users()) {
291     GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(Usr);
292     if (!GEP)
293       continue;
294 
295     if (GEP->getOperand(2) != ConstantInt::get(Type::getInt32Ty(Ctx), Idx))
296       continue;
297 
298     EXPECT_EQ(GEPAtIdx, nullptr);
299     GEPAtIdx = GEP;
300   }
301 
302   EXPECT_NE(GEPAtIdx, nullptr);
303   EXPECT_EQ(GEPAtIdx->getNumUses(), 1U);
304 
305   // Find the value stored to the aggregate.
306   StoreInst *StoreToAgg = dyn_cast<StoreInst>(*GEPAtIdx->user_begin());
307   Value *StoredAggValue = StoreToAgg->getValueOperand();
308 
309   Value *StoredValue = nullptr;
310 
311   // Find the value stored to the value stored in the aggregate.
312   for (User *Usr : StoredAggValue->users()) {
313     StoreInst *Store = dyn_cast<StoreInst>(Usr);
314     if (!Store)
315       continue;
316 
317     if (Store->getPointerOperand() != StoredAggValue)
318       continue;
319 
320     EXPECT_EQ(StoredValue, nullptr);
321     StoredValue = Store->getValueOperand();
322   }
323 
324   return StoredValue;
325 }
326 
327 // Returns the aggregate that the value is originating from.
328 static Value *findAggregateFromValue(Value *V) {
329   // Expects a load instruction that loads from the aggregate.
330   LoadInst *Load = dyn_cast<LoadInst>(V);
331   EXPECT_NE(Load, nullptr);
332   // Find the GEP instruction used in the load instruction.
333   GetElementPtrInst *GEP =
334       dyn_cast<GetElementPtrInst>(Load->getPointerOperand());
335   EXPECT_NE(GEP, nullptr);
336   // Find the aggregate used in the GEP instruction.
337   Value *Aggregate = GEP->getPointerOperand();
338 
339   return Aggregate;
340 }
341 
342 TEST_F(OpenMPIRBuilderTest, CreateBarrier) {
343   OpenMPIRBuilder OMPBuilder(*M);
344   OMPBuilder.initialize();
345 
346   IRBuilder<> Builder(BB);
347 
348   OpenMPIRBuilder::InsertPointOrErrorTy BarrierIP1 =
349       OMPBuilder.createBarrier({IRBuilder<>::InsertPoint()}, OMPD_for);
350   assert(BarrierIP1 && "unexpected error");
351   EXPECT_TRUE(M->global_empty());
352   EXPECT_EQ(M->size(), 1U);
353   EXPECT_EQ(F->size(), 1U);
354   EXPECT_EQ(BB->size(), 0U);
355 
356   OpenMPIRBuilder::LocationDescription Loc({Builder.saveIP()});
357   OpenMPIRBuilder::InsertPointOrErrorTy BarrierIP2 =
358       OMPBuilder.createBarrier(Loc, OMPD_for);
359   assert(BarrierIP2 && "unexpected error");
360   EXPECT_FALSE(M->global_empty());
361   EXPECT_EQ(M->size(), 3U);
362   EXPECT_EQ(F->size(), 1U);
363   EXPECT_EQ(BB->size(), 2U);
364 
365   CallInst *GTID = dyn_cast<CallInst>(&BB->front());
366   EXPECT_NE(GTID, nullptr);
367   EXPECT_EQ(GTID->arg_size(), 1U);
368   EXPECT_EQ(GTID->getCalledFunction()->getName(), "__kmpc_global_thread_num");
369   EXPECT_FALSE(GTID->getCalledFunction()->doesNotAccessMemory());
370   EXPECT_FALSE(GTID->getCalledFunction()->doesNotFreeMemory());
371 
372   CallInst *Barrier = dyn_cast<CallInst>(GTID->getNextNode());
373   EXPECT_NE(Barrier, nullptr);
374   EXPECT_EQ(Barrier->arg_size(), 2U);
375   EXPECT_EQ(Barrier->getCalledFunction()->getName(), "__kmpc_barrier");
376   EXPECT_FALSE(Barrier->getCalledFunction()->doesNotAccessMemory());
377   EXPECT_FALSE(Barrier->getCalledFunction()->doesNotFreeMemory());
378 
379   EXPECT_EQ(cast<CallInst>(Barrier)->getArgOperand(1), GTID);
380 
381   Builder.CreateUnreachable();
382   EXPECT_FALSE(verifyModule(*M, &errs()));
383 }
384 
385 TEST_F(OpenMPIRBuilderTest, CreateCancel) {
386   using InsertPointTy = OpenMPIRBuilder::InsertPointTy;
387   OpenMPIRBuilder OMPBuilder(*M);
388   OMPBuilder.initialize();
389 
390   BasicBlock *CBB = BasicBlock::Create(Ctx, "", F);
391   new UnreachableInst(Ctx, CBB);
392   auto FiniCB = [&](InsertPointTy IP) {
393     ASSERT_NE(IP.getBlock(), nullptr);
394     ASSERT_EQ(IP.getBlock()->end(), IP.getPoint());
395     BranchInst::Create(CBB, IP.getBlock());
396   };
397   OMPBuilder.pushFinalizationCB({FINICB_WRAPPER(FiniCB), OMPD_parallel, true});
398 
399   IRBuilder<> Builder(BB);
400 
401   OpenMPIRBuilder::LocationDescription Loc({Builder.saveIP()});
402   OpenMPIRBuilder::InsertPointOrErrorTy NewIP =
403       OMPBuilder.createCancel(Loc, nullptr, OMPD_parallel);
404   assert(NewIP && "unexpected error");
405   Builder.restoreIP(*NewIP);
406   EXPECT_FALSE(M->global_empty());
407   EXPECT_EQ(M->size(), 4U);
408   EXPECT_EQ(F->size(), 4U);
409   EXPECT_EQ(BB->size(), 4U);
410 
411   CallInst *GTID = dyn_cast<CallInst>(&BB->front());
412   EXPECT_NE(GTID, nullptr);
413   EXPECT_EQ(GTID->arg_size(), 1U);
414   EXPECT_EQ(GTID->getCalledFunction()->getName(), "__kmpc_global_thread_num");
415   EXPECT_FALSE(GTID->getCalledFunction()->doesNotAccessMemory());
416   EXPECT_FALSE(GTID->getCalledFunction()->doesNotFreeMemory());
417 
418   CallInst *Cancel = dyn_cast<CallInst>(GTID->getNextNode());
419   EXPECT_NE(Cancel, nullptr);
420   EXPECT_EQ(Cancel->arg_size(), 3U);
421   EXPECT_EQ(Cancel->getCalledFunction()->getName(), "__kmpc_cancel");
422   EXPECT_FALSE(Cancel->getCalledFunction()->doesNotAccessMemory());
423   EXPECT_FALSE(Cancel->getCalledFunction()->doesNotFreeMemory());
424   EXPECT_EQ(Cancel->getNumUses(), 1U);
425   Instruction *CancelBBTI = Cancel->getParent()->getTerminator();
426   EXPECT_EQ(CancelBBTI->getNumSuccessors(), 2U);
427   EXPECT_EQ(CancelBBTI->getSuccessor(0), NewIP->getBlock());
428   EXPECT_EQ(CancelBBTI->getSuccessor(1)->size(), 3U);
429   CallInst *GTID1 = dyn_cast<CallInst>(&CancelBBTI->getSuccessor(1)->front());
430   EXPECT_NE(GTID1, nullptr);
431   EXPECT_EQ(GTID1->arg_size(), 1U);
432   EXPECT_EQ(GTID1->getCalledFunction()->getName(), "__kmpc_global_thread_num");
433   EXPECT_FALSE(GTID1->getCalledFunction()->doesNotAccessMemory());
434   EXPECT_FALSE(GTID1->getCalledFunction()->doesNotFreeMemory());
435   CallInst *Barrier = dyn_cast<CallInst>(GTID1->getNextNode());
436   EXPECT_NE(Barrier, nullptr);
437   EXPECT_EQ(Barrier->arg_size(), 2U);
438   EXPECT_EQ(Barrier->getCalledFunction()->getName(), "__kmpc_cancel_barrier");
439   EXPECT_FALSE(Barrier->getCalledFunction()->doesNotAccessMemory());
440   EXPECT_FALSE(Barrier->getCalledFunction()->doesNotFreeMemory());
441   EXPECT_EQ(Barrier->getNumUses(), 0U);
442   EXPECT_EQ(CancelBBTI->getSuccessor(1)->getTerminator()->getNumSuccessors(),
443             1U);
444   EXPECT_EQ(CancelBBTI->getSuccessor(1)->getTerminator()->getSuccessor(0), CBB);
445 
446   EXPECT_EQ(cast<CallInst>(Cancel)->getArgOperand(1), GTID);
447 
448   OMPBuilder.popFinalizationCB();
449 
450   Builder.CreateUnreachable();
451   EXPECT_FALSE(verifyModule(*M, &errs()));
452 }
453 
454 TEST_F(OpenMPIRBuilderTest, CreateCancelIfCond) {
455   using InsertPointTy = OpenMPIRBuilder::InsertPointTy;
456   OpenMPIRBuilder OMPBuilder(*M);
457   OMPBuilder.initialize();
458 
459   BasicBlock *CBB = BasicBlock::Create(Ctx, "", F);
460   new UnreachableInst(Ctx, CBB);
461   auto FiniCB = [&](InsertPointTy IP) {
462     ASSERT_NE(IP.getBlock(), nullptr);
463     ASSERT_EQ(IP.getBlock()->end(), IP.getPoint());
464     BranchInst::Create(CBB, IP.getBlock());
465   };
466   OMPBuilder.pushFinalizationCB({FINICB_WRAPPER(FiniCB), OMPD_parallel, true});
467 
468   IRBuilder<> Builder(BB);
469 
470   OpenMPIRBuilder::LocationDescription Loc({Builder.saveIP()});
471   OpenMPIRBuilder::InsertPointOrErrorTy NewIP =
472       OMPBuilder.createCancel(Loc, Builder.getTrue(), OMPD_parallel);
473   assert(NewIP && "unexpected error");
474   Builder.restoreIP(*NewIP);
475   EXPECT_FALSE(M->global_empty());
476   EXPECT_EQ(M->size(), 4U);
477   EXPECT_EQ(F->size(), 7U);
478   EXPECT_EQ(BB->size(), 1U);
479   ASSERT_TRUE(isa<BranchInst>(BB->getTerminator()));
480   ASSERT_EQ(BB->getTerminator()->getNumSuccessors(), 2U);
481   BB = BB->getTerminator()->getSuccessor(0);
482   EXPECT_EQ(BB->size(), 4U);
483 
484   CallInst *GTID = dyn_cast<CallInst>(&BB->front());
485   EXPECT_NE(GTID, nullptr);
486   EXPECT_EQ(GTID->arg_size(), 1U);
487   EXPECT_EQ(GTID->getCalledFunction()->getName(), "__kmpc_global_thread_num");
488   EXPECT_FALSE(GTID->getCalledFunction()->doesNotAccessMemory());
489   EXPECT_FALSE(GTID->getCalledFunction()->doesNotFreeMemory());
490 
491   CallInst *Cancel = dyn_cast<CallInst>(GTID->getNextNode());
492   EXPECT_NE(Cancel, nullptr);
493   EXPECT_EQ(Cancel->arg_size(), 3U);
494   EXPECT_EQ(Cancel->getCalledFunction()->getName(), "__kmpc_cancel");
495   EXPECT_FALSE(Cancel->getCalledFunction()->doesNotAccessMemory());
496   EXPECT_FALSE(Cancel->getCalledFunction()->doesNotFreeMemory());
497   EXPECT_EQ(Cancel->getNumUses(), 1U);
498   Instruction *CancelBBTI = Cancel->getParent()->getTerminator();
499   EXPECT_EQ(CancelBBTI->getNumSuccessors(), 2U);
500   EXPECT_EQ(CancelBBTI->getSuccessor(0)->size(), 1U);
501   EXPECT_EQ(CancelBBTI->getSuccessor(0)->getUniqueSuccessor(),
502             NewIP->getBlock());
503   EXPECT_EQ(CancelBBTI->getSuccessor(1)->size(), 3U);
504   CallInst *GTID1 = dyn_cast<CallInst>(&CancelBBTI->getSuccessor(1)->front());
505   EXPECT_NE(GTID1, nullptr);
506   EXPECT_EQ(GTID1->arg_size(), 1U);
507   EXPECT_EQ(GTID1->getCalledFunction()->getName(), "__kmpc_global_thread_num");
508   EXPECT_FALSE(GTID1->getCalledFunction()->doesNotAccessMemory());
509   EXPECT_FALSE(GTID1->getCalledFunction()->doesNotFreeMemory());
510   CallInst *Barrier = dyn_cast<CallInst>(GTID1->getNextNode());
511   EXPECT_NE(Barrier, nullptr);
512   EXPECT_EQ(Barrier->arg_size(), 2U);
513   EXPECT_EQ(Barrier->getCalledFunction()->getName(), "__kmpc_cancel_barrier");
514   EXPECT_FALSE(Barrier->getCalledFunction()->doesNotAccessMemory());
515   EXPECT_FALSE(Barrier->getCalledFunction()->doesNotFreeMemory());
516   EXPECT_EQ(Barrier->getNumUses(), 0U);
517   EXPECT_EQ(CancelBBTI->getSuccessor(1)->getTerminator()->getNumSuccessors(),
518             1U);
519   EXPECT_EQ(CancelBBTI->getSuccessor(1)->getTerminator()->getSuccessor(0), CBB);
520 
521   EXPECT_EQ(cast<CallInst>(Cancel)->getArgOperand(1), GTID);
522 
523   OMPBuilder.popFinalizationCB();
524 
525   Builder.CreateUnreachable();
526   EXPECT_FALSE(verifyModule(*M, &errs()));
527 }
528 
529 TEST_F(OpenMPIRBuilderTest, CreateCancelBarrier) {
530   using InsertPointTy = OpenMPIRBuilder::InsertPointTy;
531   OpenMPIRBuilder OMPBuilder(*M);
532   OMPBuilder.initialize();
533 
534   BasicBlock *CBB = BasicBlock::Create(Ctx, "", F);
535   new UnreachableInst(Ctx, CBB);
536   auto FiniCB = [&](InsertPointTy IP) {
537     ASSERT_NE(IP.getBlock(), nullptr);
538     ASSERT_EQ(IP.getBlock()->end(), IP.getPoint());
539     BranchInst::Create(CBB, IP.getBlock());
540   };
541   OMPBuilder.pushFinalizationCB({FINICB_WRAPPER(FiniCB), OMPD_parallel, true});
542 
543   IRBuilder<> Builder(BB);
544 
545   OpenMPIRBuilder::LocationDescription Loc({Builder.saveIP()});
546   OpenMPIRBuilder::InsertPointOrErrorTy NewIP =
547       OMPBuilder.createBarrier(Loc, OMPD_for);
548   assert(NewIP && "unexpected error");
549   Builder.restoreIP(*NewIP);
550   EXPECT_FALSE(M->global_empty());
551   EXPECT_EQ(M->size(), 3U);
552   EXPECT_EQ(F->size(), 4U);
553   EXPECT_EQ(BB->size(), 4U);
554 
555   CallInst *GTID = dyn_cast<CallInst>(&BB->front());
556   EXPECT_NE(GTID, nullptr);
557   EXPECT_EQ(GTID->arg_size(), 1U);
558   EXPECT_EQ(GTID->getCalledFunction()->getName(), "__kmpc_global_thread_num");
559   EXPECT_FALSE(GTID->getCalledFunction()->doesNotAccessMemory());
560   EXPECT_FALSE(GTID->getCalledFunction()->doesNotFreeMemory());
561 
562   CallInst *Barrier = dyn_cast<CallInst>(GTID->getNextNode());
563   EXPECT_NE(Barrier, nullptr);
564   EXPECT_EQ(Barrier->arg_size(), 2U);
565   EXPECT_EQ(Barrier->getCalledFunction()->getName(), "__kmpc_cancel_barrier");
566   EXPECT_FALSE(Barrier->getCalledFunction()->doesNotAccessMemory());
567   EXPECT_FALSE(Barrier->getCalledFunction()->doesNotFreeMemory());
568   EXPECT_EQ(Barrier->getNumUses(), 1U);
569   Instruction *BarrierBBTI = Barrier->getParent()->getTerminator();
570   EXPECT_EQ(BarrierBBTI->getNumSuccessors(), 2U);
571   EXPECT_EQ(BarrierBBTI->getSuccessor(0), NewIP->getBlock());
572   EXPECT_EQ(BarrierBBTI->getSuccessor(1)->size(), 1U);
573   EXPECT_EQ(BarrierBBTI->getSuccessor(1)->getTerminator()->getNumSuccessors(),
574             1U);
575   EXPECT_EQ(BarrierBBTI->getSuccessor(1)->getTerminator()->getSuccessor(0),
576             CBB);
577 
578   EXPECT_EQ(cast<CallInst>(Barrier)->getArgOperand(1), GTID);
579 
580   OMPBuilder.popFinalizationCB();
581 
582   Builder.CreateUnreachable();
583   EXPECT_FALSE(verifyModule(*M, &errs()));
584 }
585 
586 TEST_F(OpenMPIRBuilderTest, DbgLoc) {
587   OpenMPIRBuilder OMPBuilder(*M);
588   OMPBuilder.initialize();
589   F->setName("func");
590 
591   IRBuilder<> Builder(BB);
592 
593   OpenMPIRBuilder::LocationDescription Loc({Builder.saveIP(), DL});
594   OpenMPIRBuilder::InsertPointOrErrorTy AfterIP =
595       OMPBuilder.createBarrier(Loc, OMPD_for);
596   assert(AfterIP && "unexpected error");
597   CallInst *GTID = dyn_cast<CallInst>(&BB->front());
598   CallInst *Barrier = dyn_cast<CallInst>(GTID->getNextNode());
599   EXPECT_EQ(GTID->getDebugLoc(), DL);
600   EXPECT_EQ(Barrier->getDebugLoc(), DL);
601   EXPECT_TRUE(isa<GlobalVariable>(Barrier->getOperand(0)));
602   if (!isa<GlobalVariable>(Barrier->getOperand(0)))
603     return;
604   GlobalVariable *Ident = cast<GlobalVariable>(Barrier->getOperand(0));
605   EXPECT_TRUE(Ident->hasInitializer());
606   if (!Ident->hasInitializer())
607     return;
608   Constant *Initializer = Ident->getInitializer();
609   EXPECT_TRUE(
610       isa<GlobalVariable>(Initializer->getOperand(4)->stripPointerCasts()));
611   GlobalVariable *SrcStrGlob =
612       cast<GlobalVariable>(Initializer->getOperand(4)->stripPointerCasts());
613   if (!SrcStrGlob)
614     return;
615   EXPECT_TRUE(isa<ConstantDataArray>(SrcStrGlob->getInitializer()));
616   ConstantDataArray *SrcSrc =
617       dyn_cast<ConstantDataArray>(SrcStrGlob->getInitializer());
618   if (!SrcSrc)
619     return;
620   EXPECT_EQ(SrcSrc->getAsCString(), ";/src/test.dbg;foo;3;7;;");
621 }
622 
623 TEST_F(OpenMPIRBuilderTest, ParallelSimpleGPU) {
624   using InsertPointTy = OpenMPIRBuilder::InsertPointTy;
625   std::string oldDLStr = M->getDataLayoutStr();
626   M->setDataLayout(
627       "e-p:64:64-p1:64:64-p2:32:32-p3:32:32-p4:64:64-p5:32:32-p6:32:32-p7:160:"
628       "256:256:32-p8:128:128-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:"
629       "256-v256:256-v512:512-v1024:1024-v2048:2048-n32:64-S32-A5-G1-ni:7:8");
630   OpenMPIRBuilder OMPBuilder(*M);
631   OMPBuilder.Config.IsTargetDevice = true;
632   OMPBuilder.initialize();
633   F->setName("func");
634   IRBuilder<> Builder(BB);
635   BasicBlock *EnterBB = BasicBlock::Create(Ctx, "parallel.enter", F);
636   Builder.CreateBr(EnterBB);
637   Builder.SetInsertPoint(EnterBB);
638   OpenMPIRBuilder::LocationDescription Loc({Builder.saveIP(), DL});
639 
640   AllocaInst *PrivAI = nullptr;
641 
642   unsigned NumBodiesGenerated = 0;
643   unsigned NumPrivatizedVars = 0;
644   unsigned NumFinalizationPoints = 0;
645 
646   auto BodyGenCB = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP) {
647     ++NumBodiesGenerated;
648 
649     Builder.restoreIP(AllocaIP);
650     PrivAI = Builder.CreateAlloca(F->arg_begin()->getType());
651     Builder.CreateStore(F->arg_begin(), PrivAI);
652 
653     Builder.restoreIP(CodeGenIP);
654     Value *PrivLoad =
655         Builder.CreateLoad(PrivAI->getAllocatedType(), PrivAI, "local.use");
656     Value *Cmp = Builder.CreateICmpNE(F->arg_begin(), PrivLoad);
657     Instruction *ThenTerm, *ElseTerm;
658     SplitBlockAndInsertIfThenElse(Cmp, CodeGenIP.getBlock()->getTerminator(),
659                                   &ThenTerm, &ElseTerm);
660     return Error::success();
661   };
662 
663   auto PrivCB = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP,
664                     Value &Orig, Value &Inner,
665                     Value *&ReplacementValue) -> InsertPointTy {
666     ++NumPrivatizedVars;
667 
668     if (!isa<AllocaInst>(Orig)) {
669       EXPECT_EQ(&Orig, F->arg_begin());
670       ReplacementValue = &Inner;
671       return CodeGenIP;
672     }
673 
674     // Since the original value is an allocation, it has a pointer type and
675     // therefore no additional wrapping should happen.
676     EXPECT_EQ(&Orig, &Inner);
677 
678     // Trivial copy (=firstprivate).
679     Builder.restoreIP(AllocaIP);
680     Type *VTy = ReplacementValue->getType();
681     Value *V = Builder.CreateLoad(VTy, &Inner, Orig.getName() + ".reload");
682     ReplacementValue = Builder.CreateAlloca(VTy, 0, Orig.getName() + ".copy");
683     Builder.restoreIP(CodeGenIP);
684     Builder.CreateStore(V, ReplacementValue);
685     return CodeGenIP;
686   };
687 
688   auto FiniCB = [&](InsertPointTy CodeGenIP) {
689     ++NumFinalizationPoints;
690     return Error::success();
691   };
692 
693   IRBuilder<>::InsertPoint AllocaIP(&F->getEntryBlock(),
694                                     F->getEntryBlock().getFirstInsertionPt());
695   OpenMPIRBuilder::InsertPointOrErrorTy AfterIP =
696       OMPBuilder.createParallel(Loc, AllocaIP, BodyGenCB, PrivCB, FiniCB,
697                                 nullptr, nullptr, OMP_PROC_BIND_default, false);
698   assert(AfterIP && "unexpected error");
699 
700   EXPECT_EQ(NumBodiesGenerated, 1U);
701   EXPECT_EQ(NumPrivatizedVars, 1U);
702   EXPECT_EQ(NumFinalizationPoints, 1U);
703 
704   Builder.restoreIP(*AfterIP);
705   Builder.CreateRetVoid();
706 
707   OMPBuilder.finalize();
708   Function *OutlinedFn = PrivAI->getFunction();
709   EXPECT_FALSE(verifyModule(*M, &errs()));
710   EXPECT_NE(OutlinedFn, F);
711   EXPECT_TRUE(OutlinedFn->hasFnAttribute(Attribute::NoUnwind));
712   EXPECT_TRUE(OutlinedFn->hasParamAttribute(0, Attribute::NoAlias));
713   EXPECT_TRUE(OutlinedFn->hasParamAttribute(1, Attribute::NoAlias));
714 
715   EXPECT_TRUE(OutlinedFn->hasInternalLinkage());
716   EXPECT_EQ(OutlinedFn->arg_size(), 3U);
717   // Make sure that arguments are pointers in 0 address address space
718   EXPECT_EQ(OutlinedFn->getArg(0)->getType(),
719             PointerType::get(M->getContext(), 0));
720   EXPECT_EQ(OutlinedFn->getArg(1)->getType(),
721             PointerType::get(M->getContext(), 0));
722   EXPECT_EQ(OutlinedFn->getArg(2)->getType(),
723             PointerType::get(M->getContext(), 0));
724   EXPECT_EQ(&OutlinedFn->getEntryBlock(), PrivAI->getParent());
725   EXPECT_EQ(OutlinedFn->getNumUses(), 1U);
726   User *Usr = OutlinedFn->user_back();
727   ASSERT_TRUE(isa<CallInst>(Usr));
728   CallInst *Parallel51CI = dyn_cast<CallInst>(Usr);
729   ASSERT_NE(Parallel51CI, nullptr);
730 
731   EXPECT_EQ(Parallel51CI->getCalledFunction()->getName(), "__kmpc_parallel_51");
732   EXPECT_EQ(Parallel51CI->arg_size(), 9U);
733   EXPECT_EQ(Parallel51CI->getArgOperand(5), OutlinedFn);
734   EXPECT_TRUE(
735       isa<GlobalVariable>(Parallel51CI->getArgOperand(0)->stripPointerCasts()));
736   EXPECT_EQ(Parallel51CI, Usr);
737   M->setDataLayout(oldDLStr);
738 }
739 
740 TEST_F(OpenMPIRBuilderTest, ParallelSimple) {
741   using InsertPointTy = OpenMPIRBuilder::InsertPointTy;
742   OpenMPIRBuilder OMPBuilder(*M);
743   OMPBuilder.Config.IsTargetDevice = false;
744   OMPBuilder.initialize();
745   F->setName("func");
746   IRBuilder<> Builder(BB);
747 
748   BasicBlock *EnterBB = BasicBlock::Create(Ctx, "parallel.enter", F);
749   Builder.CreateBr(EnterBB);
750   Builder.SetInsertPoint(EnterBB);
751   OpenMPIRBuilder::LocationDescription Loc({Builder.saveIP(), DL});
752 
753   AllocaInst *PrivAI = nullptr;
754 
755   unsigned NumBodiesGenerated = 0;
756   unsigned NumPrivatizedVars = 0;
757   unsigned NumFinalizationPoints = 0;
758 
759   auto BodyGenCB = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP) {
760     ++NumBodiesGenerated;
761 
762     Builder.restoreIP(AllocaIP);
763     PrivAI = Builder.CreateAlloca(F->arg_begin()->getType());
764     Builder.CreateStore(F->arg_begin(), PrivAI);
765 
766     Builder.restoreIP(CodeGenIP);
767     Value *PrivLoad =
768         Builder.CreateLoad(PrivAI->getAllocatedType(), PrivAI, "local.use");
769     Value *Cmp = Builder.CreateICmpNE(F->arg_begin(), PrivLoad);
770     Instruction *ThenTerm, *ElseTerm;
771     SplitBlockAndInsertIfThenElse(Cmp, CodeGenIP.getBlock()->getTerminator(),
772                                   &ThenTerm, &ElseTerm);
773     return Error::success();
774   };
775 
776   auto PrivCB = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP,
777                     Value &Orig, Value &Inner,
778                     Value *&ReplacementValue) -> InsertPointTy {
779     ++NumPrivatizedVars;
780 
781     if (!isa<AllocaInst>(Orig)) {
782       EXPECT_EQ(&Orig, F->arg_begin());
783       ReplacementValue = &Inner;
784       return CodeGenIP;
785     }
786 
787     // Since the original value is an allocation, it has a pointer type and
788     // therefore no additional wrapping should happen.
789     EXPECT_EQ(&Orig, &Inner);
790 
791     // Trivial copy (=firstprivate).
792     Builder.restoreIP(AllocaIP);
793     Type *VTy = ReplacementValue->getType();
794     Value *V = Builder.CreateLoad(VTy, &Inner, Orig.getName() + ".reload");
795     ReplacementValue = Builder.CreateAlloca(VTy, 0, Orig.getName() + ".copy");
796     Builder.restoreIP(CodeGenIP);
797     Builder.CreateStore(V, ReplacementValue);
798     return CodeGenIP;
799   };
800 
801   auto FiniCB = [&](InsertPointTy CodeGenIP) {
802     ++NumFinalizationPoints;
803     return Error::success();
804   };
805 
806   IRBuilder<>::InsertPoint AllocaIP(&F->getEntryBlock(),
807                                     F->getEntryBlock().getFirstInsertionPt());
808   OpenMPIRBuilder::InsertPointOrErrorTy AfterIP =
809       OMPBuilder.createParallel(Loc, AllocaIP, BodyGenCB, PrivCB, FiniCB,
810                                 nullptr, nullptr, OMP_PROC_BIND_default, false);
811   assert(AfterIP && "unexpected error");
812   EXPECT_EQ(NumBodiesGenerated, 1U);
813   EXPECT_EQ(NumPrivatizedVars, 1U);
814   EXPECT_EQ(NumFinalizationPoints, 1U);
815 
816   Builder.restoreIP(*AfterIP);
817   Builder.CreateRetVoid();
818 
819   OMPBuilder.finalize();
820 
821   EXPECT_NE(PrivAI, nullptr);
822   Function *OutlinedFn = PrivAI->getFunction();
823   EXPECT_NE(F, OutlinedFn);
824   EXPECT_FALSE(verifyModule(*M, &errs()));
825   EXPECT_TRUE(OutlinedFn->hasFnAttribute(Attribute::NoUnwind));
826   EXPECT_TRUE(OutlinedFn->hasParamAttribute(0, Attribute::NoAlias));
827   EXPECT_TRUE(OutlinedFn->hasParamAttribute(1, Attribute::NoAlias));
828 
829   EXPECT_TRUE(OutlinedFn->hasInternalLinkage());
830   EXPECT_EQ(OutlinedFn->arg_size(), 3U);
831 
832   EXPECT_EQ(&OutlinedFn->getEntryBlock(), PrivAI->getParent());
833   EXPECT_EQ(OutlinedFn->getNumUses(), 1U);
834   User *Usr = OutlinedFn->user_back();
835   ASSERT_TRUE(isa<CallInst>(Usr));
836   CallInst *ForkCI = dyn_cast<CallInst>(Usr);
837   ASSERT_NE(ForkCI, nullptr);
838 
839   EXPECT_EQ(ForkCI->getCalledFunction()->getName(), "__kmpc_fork_call");
840   EXPECT_EQ(ForkCI->arg_size(), 4U);
841   EXPECT_TRUE(isa<GlobalVariable>(ForkCI->getArgOperand(0)));
842   EXPECT_EQ(ForkCI->getArgOperand(1),
843             ConstantInt::get(Type::getInt32Ty(Ctx), 1U));
844   EXPECT_EQ(ForkCI, Usr);
845   Value *StoredValue =
846       findStoredValueInAggregateAt(Ctx, ForkCI->getArgOperand(3), 0);
847   EXPECT_EQ(StoredValue, F->arg_begin());
848 }
849 
850 TEST_F(OpenMPIRBuilderTest, ParallelNested) {
851   using InsertPointTy = OpenMPIRBuilder::InsertPointTy;
852   OpenMPIRBuilder OMPBuilder(*M);
853   OMPBuilder.Config.IsTargetDevice = false;
854   OMPBuilder.initialize();
855   F->setName("func");
856   IRBuilder<> Builder(BB);
857 
858   BasicBlock *EnterBB = BasicBlock::Create(Ctx, "parallel.enter", F);
859   Builder.CreateBr(EnterBB);
860   Builder.SetInsertPoint(EnterBB);
861   OpenMPIRBuilder::LocationDescription Loc({Builder.saveIP(), DL});
862 
863   unsigned NumInnerBodiesGenerated = 0;
864   unsigned NumOuterBodiesGenerated = 0;
865   unsigned NumFinalizationPoints = 0;
866 
867   auto InnerBodyGenCB = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP) {
868     ++NumInnerBodiesGenerated;
869     return Error::success();
870   };
871 
872   auto PrivCB = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP,
873                     Value &Orig, Value &Inner,
874                     Value *&ReplacementValue) -> InsertPointTy {
875     // Trivial copy (=firstprivate).
876     Builder.restoreIP(AllocaIP);
877     Type *VTy = ReplacementValue->getType();
878     Value *V = Builder.CreateLoad(VTy, &Inner, Orig.getName() + ".reload");
879     ReplacementValue = Builder.CreateAlloca(VTy, 0, Orig.getName() + ".copy");
880     Builder.restoreIP(CodeGenIP);
881     Builder.CreateStore(V, ReplacementValue);
882     return CodeGenIP;
883   };
884 
885   auto FiniCB = [&](InsertPointTy CodeGenIP) {
886     ++NumFinalizationPoints;
887     return Error::success();
888   };
889 
890   auto OuterBodyGenCB = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP) {
891     ++NumOuterBodiesGenerated;
892     Builder.restoreIP(CodeGenIP);
893     BasicBlock *CGBB = CodeGenIP.getBlock();
894     BasicBlock *NewBB = SplitBlock(CGBB, &*CodeGenIP.getPoint());
895     CGBB->getTerminator()->eraseFromParent();
896 
897     OpenMPIRBuilder::InsertPointOrErrorTy AfterIP = OMPBuilder.createParallel(
898         InsertPointTy(CGBB, CGBB->end()), AllocaIP, InnerBodyGenCB, PrivCB,
899         FiniCB, nullptr, nullptr, OMP_PROC_BIND_default, false);
900     assert(AfterIP && "unexpected error");
901 
902     Builder.restoreIP(*AfterIP);
903     Builder.CreateBr(NewBB);
904     return Error::success();
905   };
906 
907   IRBuilder<>::InsertPoint AllocaIP(&F->getEntryBlock(),
908                                     F->getEntryBlock().getFirstInsertionPt());
909   OpenMPIRBuilder::InsertPointOrErrorTy AfterIP =
910       OMPBuilder.createParallel(Loc, AllocaIP, OuterBodyGenCB, PrivCB, FiniCB,
911                                 nullptr, nullptr, OMP_PROC_BIND_default, false);
912   assert(AfterIP && "unexpected error");
913 
914   EXPECT_EQ(NumInnerBodiesGenerated, 1U);
915   EXPECT_EQ(NumOuterBodiesGenerated, 1U);
916   EXPECT_EQ(NumFinalizationPoints, 2U);
917 
918   Builder.restoreIP(*AfterIP);
919   Builder.CreateRetVoid();
920 
921   OMPBuilder.finalize();
922 
923   EXPECT_EQ(M->size(), 5U);
924   for (Function &OutlinedFn : *M) {
925     if (F == &OutlinedFn || OutlinedFn.isDeclaration())
926       continue;
927     EXPECT_FALSE(verifyModule(*M, &errs()));
928     EXPECT_TRUE(OutlinedFn.hasFnAttribute(Attribute::NoUnwind));
929     EXPECT_TRUE(OutlinedFn.hasParamAttribute(0, Attribute::NoAlias));
930     EXPECT_TRUE(OutlinedFn.hasParamAttribute(1, Attribute::NoAlias));
931 
932     EXPECT_TRUE(OutlinedFn.hasInternalLinkage());
933     EXPECT_EQ(OutlinedFn.arg_size(), 2U);
934 
935     EXPECT_EQ(OutlinedFn.getNumUses(), 1U);
936     User *Usr = OutlinedFn.user_back();
937     ASSERT_TRUE(isa<CallInst>(Usr));
938     CallInst *ForkCI = dyn_cast<CallInst>(Usr);
939     ASSERT_NE(ForkCI, nullptr);
940 
941     EXPECT_EQ(ForkCI->getCalledFunction()->getName(), "__kmpc_fork_call");
942     EXPECT_EQ(ForkCI->arg_size(), 3U);
943     EXPECT_TRUE(isa<GlobalVariable>(ForkCI->getArgOperand(0)));
944     EXPECT_EQ(ForkCI->getArgOperand(1),
945               ConstantInt::get(Type::getInt32Ty(Ctx), 0U));
946     EXPECT_EQ(ForkCI, Usr);
947   }
948 }
949 
950 TEST_F(OpenMPIRBuilderTest, ParallelNested2Inner) {
951   using InsertPointTy = OpenMPIRBuilder::InsertPointTy;
952   OpenMPIRBuilder OMPBuilder(*M);
953   OMPBuilder.Config.IsTargetDevice = false;
954   OMPBuilder.initialize();
955   F->setName("func");
956   IRBuilder<> Builder(BB);
957 
958   BasicBlock *EnterBB = BasicBlock::Create(Ctx, "parallel.enter", F);
959   Builder.CreateBr(EnterBB);
960   Builder.SetInsertPoint(EnterBB);
961   OpenMPIRBuilder::LocationDescription Loc({Builder.saveIP(), DL});
962 
963   unsigned NumInnerBodiesGenerated = 0;
964   unsigned NumOuterBodiesGenerated = 0;
965   unsigned NumFinalizationPoints = 0;
966 
967   auto InnerBodyGenCB = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP) {
968     ++NumInnerBodiesGenerated;
969     return Error::success();
970   };
971 
972   auto PrivCB = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP,
973                     Value &Orig, Value &Inner,
974                     Value *&ReplacementValue) -> InsertPointTy {
975     // Trivial copy (=firstprivate).
976     Builder.restoreIP(AllocaIP);
977     Type *VTy = ReplacementValue->getType();
978     Value *V = Builder.CreateLoad(VTy, &Inner, Orig.getName() + ".reload");
979     ReplacementValue = Builder.CreateAlloca(VTy, 0, Orig.getName() + ".copy");
980     Builder.restoreIP(CodeGenIP);
981     Builder.CreateStore(V, ReplacementValue);
982     return CodeGenIP;
983   };
984 
985   auto FiniCB = [&](InsertPointTy CodeGenIP) {
986     ++NumFinalizationPoints;
987     return Error::success();
988   };
989 
990   auto OuterBodyGenCB = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP) {
991     ++NumOuterBodiesGenerated;
992     Builder.restoreIP(CodeGenIP);
993     BasicBlock *CGBB = CodeGenIP.getBlock();
994     BasicBlock *NewBB1 = SplitBlock(CGBB, &*CodeGenIP.getPoint());
995     BasicBlock *NewBB2 = SplitBlock(NewBB1, &*NewBB1->getFirstInsertionPt());
996     CGBB->getTerminator()->eraseFromParent();
997     ;
998     NewBB1->getTerminator()->eraseFromParent();
999     ;
1000 
1001     OpenMPIRBuilder::InsertPointOrErrorTy AfterIP1 = OMPBuilder.createParallel(
1002         InsertPointTy(CGBB, CGBB->end()), AllocaIP, InnerBodyGenCB, PrivCB,
1003         FiniCB, nullptr, nullptr, OMP_PROC_BIND_default, false);
1004     assert(AfterIP1 && "unexpected error");
1005 
1006     Builder.restoreIP(*AfterIP1);
1007     Builder.CreateBr(NewBB1);
1008 
1009     OpenMPIRBuilder::InsertPointOrErrorTy AfterIP2 = OMPBuilder.createParallel(
1010         InsertPointTy(NewBB1, NewBB1->end()), AllocaIP, InnerBodyGenCB, PrivCB,
1011         FiniCB, nullptr, nullptr, OMP_PROC_BIND_default, false);
1012     assert(AfterIP2 && "unexpected error");
1013 
1014     Builder.restoreIP(*AfterIP2);
1015     Builder.CreateBr(NewBB2);
1016     return Error::success();
1017   };
1018 
1019   IRBuilder<>::InsertPoint AllocaIP(&F->getEntryBlock(),
1020                                     F->getEntryBlock().getFirstInsertionPt());
1021   OpenMPIRBuilder::InsertPointOrErrorTy AfterIP =
1022       OMPBuilder.createParallel(Loc, AllocaIP, OuterBodyGenCB, PrivCB, FiniCB,
1023                                 nullptr, nullptr, OMP_PROC_BIND_default, false);
1024   assert(AfterIP && "unexpected error");
1025 
1026   EXPECT_EQ(NumInnerBodiesGenerated, 2U);
1027   EXPECT_EQ(NumOuterBodiesGenerated, 1U);
1028   EXPECT_EQ(NumFinalizationPoints, 3U);
1029 
1030   Builder.restoreIP(*AfterIP);
1031   Builder.CreateRetVoid();
1032 
1033   OMPBuilder.finalize();
1034 
1035   EXPECT_EQ(M->size(), 6U);
1036   for (Function &OutlinedFn : *M) {
1037     if (F == &OutlinedFn || OutlinedFn.isDeclaration())
1038       continue;
1039     EXPECT_FALSE(verifyModule(*M, &errs()));
1040     EXPECT_TRUE(OutlinedFn.hasFnAttribute(Attribute::NoUnwind));
1041     EXPECT_TRUE(OutlinedFn.hasParamAttribute(0, Attribute::NoAlias));
1042     EXPECT_TRUE(OutlinedFn.hasParamAttribute(1, Attribute::NoAlias));
1043 
1044     EXPECT_TRUE(OutlinedFn.hasInternalLinkage());
1045     EXPECT_EQ(OutlinedFn.arg_size(), 2U);
1046 
1047     unsigned NumAllocas = 0;
1048     for (Instruction &I : instructions(OutlinedFn))
1049       NumAllocas += isa<AllocaInst>(I);
1050     EXPECT_EQ(NumAllocas, 1U);
1051 
1052     EXPECT_EQ(OutlinedFn.getNumUses(), 1U);
1053     User *Usr = OutlinedFn.user_back();
1054     ASSERT_TRUE(isa<CallInst>(Usr));
1055     CallInst *ForkCI = dyn_cast<CallInst>(Usr);
1056     ASSERT_NE(ForkCI, nullptr);
1057 
1058     EXPECT_EQ(ForkCI->getCalledFunction()->getName(), "__kmpc_fork_call");
1059     EXPECT_EQ(ForkCI->arg_size(), 3U);
1060     EXPECT_TRUE(isa<GlobalVariable>(ForkCI->getArgOperand(0)));
1061     EXPECT_EQ(ForkCI->getArgOperand(1),
1062               ConstantInt::get(Type::getInt32Ty(Ctx), 0U));
1063     EXPECT_EQ(ForkCI, Usr);
1064   }
1065 }
1066 
1067 TEST_F(OpenMPIRBuilderTest, ParallelIfCond) {
1068   using InsertPointTy = OpenMPIRBuilder::InsertPointTy;
1069   OpenMPIRBuilder OMPBuilder(*M);
1070   OMPBuilder.Config.IsTargetDevice = false;
1071   OMPBuilder.initialize();
1072   F->setName("func");
1073   IRBuilder<> Builder(BB);
1074 
1075   BasicBlock *EnterBB = BasicBlock::Create(Ctx, "parallel.enter", F);
1076   Builder.CreateBr(EnterBB);
1077   Builder.SetInsertPoint(EnterBB);
1078   OpenMPIRBuilder::LocationDescription Loc({Builder.saveIP(), DL});
1079 
1080   AllocaInst *PrivAI = nullptr;
1081 
1082   unsigned NumBodiesGenerated = 0;
1083   unsigned NumPrivatizedVars = 0;
1084   unsigned NumFinalizationPoints = 0;
1085 
1086   auto BodyGenCB = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP) {
1087     ++NumBodiesGenerated;
1088 
1089     Builder.restoreIP(AllocaIP);
1090     PrivAI = Builder.CreateAlloca(F->arg_begin()->getType());
1091     Builder.CreateStore(F->arg_begin(), PrivAI);
1092 
1093     Builder.restoreIP(CodeGenIP);
1094     Value *PrivLoad =
1095         Builder.CreateLoad(PrivAI->getAllocatedType(), PrivAI, "local.use");
1096     Value *Cmp = Builder.CreateICmpNE(F->arg_begin(), PrivLoad);
1097     Instruction *ThenTerm, *ElseTerm;
1098     SplitBlockAndInsertIfThenElse(Cmp, &*Builder.GetInsertPoint(), &ThenTerm,
1099                                   &ElseTerm);
1100     return Error::success();
1101   };
1102 
1103   auto PrivCB = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP,
1104                     Value &Orig, Value &Inner,
1105                     Value *&ReplacementValue) -> InsertPointTy {
1106     ++NumPrivatizedVars;
1107 
1108     if (!isa<AllocaInst>(Orig)) {
1109       EXPECT_EQ(&Orig, F->arg_begin());
1110       ReplacementValue = &Inner;
1111       return CodeGenIP;
1112     }
1113 
1114     // Since the original value is an allocation, it has a pointer type and
1115     // therefore no additional wrapping should happen.
1116     EXPECT_EQ(&Orig, &Inner);
1117 
1118     // Trivial copy (=firstprivate).
1119     Builder.restoreIP(AllocaIP);
1120     Type *VTy = ReplacementValue->getType();
1121     Value *V = Builder.CreateLoad(VTy, &Inner, Orig.getName() + ".reload");
1122     ReplacementValue = Builder.CreateAlloca(VTy, 0, Orig.getName() + ".copy");
1123     Builder.restoreIP(CodeGenIP);
1124     Builder.CreateStore(V, ReplacementValue);
1125     return CodeGenIP;
1126   };
1127 
1128   auto FiniCB = [&](InsertPointTy CodeGenIP) {
1129     ++NumFinalizationPoints;
1130     // No destructors.
1131     return Error::success();
1132   };
1133 
1134   IRBuilder<>::InsertPoint AllocaIP(&F->getEntryBlock(),
1135                                     F->getEntryBlock().getFirstInsertionPt());
1136   OpenMPIRBuilder::InsertPointOrErrorTy AfterIP =
1137       OMPBuilder.createParallel(Loc, AllocaIP, BodyGenCB, PrivCB, FiniCB,
1138                                 Builder.CreateIsNotNull(F->arg_begin()),
1139                                 nullptr, OMP_PROC_BIND_default, false);
1140   assert(AfterIP && "unexpected error");
1141 
1142   EXPECT_EQ(NumBodiesGenerated, 1U);
1143   EXPECT_EQ(NumPrivatizedVars, 1U);
1144   EXPECT_EQ(NumFinalizationPoints, 1U);
1145 
1146   Builder.restoreIP(*AfterIP);
1147   Builder.CreateRetVoid();
1148   OMPBuilder.finalize();
1149 
1150   EXPECT_NE(PrivAI, nullptr);
1151   Function *OutlinedFn = PrivAI->getFunction();
1152   EXPECT_NE(F, OutlinedFn);
1153   EXPECT_FALSE(verifyModule(*M, &errs()));
1154 
1155   EXPECT_TRUE(OutlinedFn->hasInternalLinkage());
1156   EXPECT_EQ(OutlinedFn->arg_size(), 3U);
1157 
1158   EXPECT_EQ(&OutlinedFn->getEntryBlock(), PrivAI->getParent());
1159   ASSERT_EQ(OutlinedFn->getNumUses(), 1U);
1160 
1161   CallInst *ForkCI = nullptr;
1162   for (User *Usr : OutlinedFn->users()) {
1163     ASSERT_TRUE(isa<CallInst>(Usr));
1164     ForkCI = cast<CallInst>(Usr);
1165   }
1166 
1167   EXPECT_EQ(ForkCI->getCalledFunction()->getName(), "__kmpc_fork_call_if");
1168   EXPECT_EQ(ForkCI->arg_size(), 5U);
1169   EXPECT_TRUE(isa<GlobalVariable>(ForkCI->getArgOperand(0)));
1170   EXPECT_EQ(ForkCI->getArgOperand(1),
1171             ConstantInt::get(Type::getInt32Ty(Ctx), 1));
1172   EXPECT_EQ(ForkCI->getArgOperand(3)->getType(), Type::getInt32Ty(Ctx));
1173 }
1174 
1175 TEST_F(OpenMPIRBuilderTest, ParallelCancelBarrier) {
1176   using InsertPointTy = OpenMPIRBuilder::InsertPointTy;
1177   OpenMPIRBuilder OMPBuilder(*M);
1178   OMPBuilder.Config.IsTargetDevice = false;
1179   OMPBuilder.initialize();
1180   F->setName("func");
1181   IRBuilder<> Builder(BB);
1182 
1183   BasicBlock *EnterBB = BasicBlock::Create(Ctx, "parallel.enter", F);
1184   Builder.CreateBr(EnterBB);
1185   Builder.SetInsertPoint(EnterBB);
1186   OpenMPIRBuilder::LocationDescription Loc({Builder.saveIP(), DL});
1187 
1188   unsigned NumBodiesGenerated = 0;
1189   unsigned NumPrivatizedVars = 0;
1190   unsigned NumFinalizationPoints = 0;
1191 
1192   CallInst *CheckedBarrier = nullptr;
1193   auto BodyGenCB = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP) {
1194     ++NumBodiesGenerated;
1195 
1196     Builder.restoreIP(CodeGenIP);
1197 
1198     // Create three barriers, two cancel barriers but only one checked.
1199     Function *CBFn, *BFn;
1200 
1201     OpenMPIRBuilder::InsertPointOrErrorTy BarrierIP1 =
1202         OMPBuilder.createBarrier(Builder.saveIP(), OMPD_parallel);
1203     assert(BarrierIP1 && "unexpected error");
1204     Builder.restoreIP(*BarrierIP1);
1205 
1206     CBFn = M->getFunction("__kmpc_cancel_barrier");
1207     BFn = M->getFunction("__kmpc_barrier");
1208     ASSERT_NE(CBFn, nullptr);
1209     ASSERT_EQ(BFn, nullptr);
1210     ASSERT_EQ(CBFn->getNumUses(), 1U);
1211     ASSERT_TRUE(isa<CallInst>(CBFn->user_back()));
1212     ASSERT_EQ(CBFn->user_back()->getNumUses(), 1U);
1213     CheckedBarrier = cast<CallInst>(CBFn->user_back());
1214 
1215     OpenMPIRBuilder::InsertPointOrErrorTy BarrierIP2 =
1216         OMPBuilder.createBarrier(Builder.saveIP(), OMPD_parallel, true);
1217     assert(BarrierIP2 && "unexpected error");
1218     Builder.restoreIP(*BarrierIP2);
1219     CBFn = M->getFunction("__kmpc_cancel_barrier");
1220     BFn = M->getFunction("__kmpc_barrier");
1221     ASSERT_NE(CBFn, nullptr);
1222     ASSERT_NE(BFn, nullptr);
1223     ASSERT_EQ(CBFn->getNumUses(), 1U);
1224     ASSERT_EQ(BFn->getNumUses(), 1U);
1225     ASSERT_TRUE(isa<CallInst>(BFn->user_back()));
1226     ASSERT_EQ(BFn->user_back()->getNumUses(), 0U);
1227 
1228     OpenMPIRBuilder::InsertPointOrErrorTy BarrierIP3 =
1229         OMPBuilder.createBarrier(Builder.saveIP(), OMPD_parallel, false, false);
1230     assert(BarrierIP3 && "unexpected error");
1231     Builder.restoreIP(*BarrierIP3);
1232     ASSERT_EQ(CBFn->getNumUses(), 2U);
1233     ASSERT_EQ(BFn->getNumUses(), 1U);
1234     ASSERT_TRUE(CBFn->user_back() != CheckedBarrier);
1235     ASSERT_TRUE(isa<CallInst>(CBFn->user_back()));
1236     ASSERT_EQ(CBFn->user_back()->getNumUses(), 0U);
1237   };
1238 
1239   auto PrivCB = [&](InsertPointTy, InsertPointTy, Value &V, Value &,
1240                     Value *&) -> InsertPointTy {
1241     ++NumPrivatizedVars;
1242     llvm_unreachable("No privatization callback call expected!");
1243   };
1244 
1245   FunctionType *FakeDestructorTy =
1246       FunctionType::get(Type::getVoidTy(Ctx), {Type::getInt32Ty(Ctx)},
1247                         /*isVarArg=*/false);
1248   auto *FakeDestructor = Function::Create(
1249       FakeDestructorTy, Function::ExternalLinkage, "fakeDestructor", M.get());
1250 
1251   auto FiniCB = [&](InsertPointTy IP) {
1252     ++NumFinalizationPoints;
1253     Builder.restoreIP(IP);
1254     Builder.CreateCall(FakeDestructor,
1255                        {Builder.getInt32(NumFinalizationPoints)});
1256     return Error::success();
1257   };
1258 
1259   IRBuilder<>::InsertPoint AllocaIP(&F->getEntryBlock(),
1260                                     F->getEntryBlock().getFirstInsertionPt());
1261   OpenMPIRBuilder::InsertPointOrErrorTy AfterIP = OMPBuilder.createParallel(
1262       Loc, AllocaIP, BODYGENCB_WRAPPER(BodyGenCB), PrivCB, FiniCB,
1263       Builder.CreateIsNotNull(F->arg_begin()), nullptr, OMP_PROC_BIND_default,
1264       true);
1265   assert(AfterIP && "unexpected error");
1266 
1267   EXPECT_EQ(NumBodiesGenerated, 1U);
1268   EXPECT_EQ(NumPrivatizedVars, 0U);
1269   EXPECT_EQ(NumFinalizationPoints, 2U);
1270   EXPECT_EQ(FakeDestructor->getNumUses(), 2U);
1271 
1272   Builder.restoreIP(*AfterIP);
1273   Builder.CreateRetVoid();
1274   OMPBuilder.finalize();
1275 
1276   EXPECT_FALSE(verifyModule(*M, &errs()));
1277 
1278   BasicBlock *ExitBB = nullptr;
1279   for (const User *Usr : FakeDestructor->users()) {
1280     const CallInst *CI = dyn_cast<CallInst>(Usr);
1281     ASSERT_EQ(CI->getCalledFunction(), FakeDestructor);
1282     ASSERT_TRUE(isa<BranchInst>(CI->getNextNode()));
1283     ASSERT_EQ(CI->getNextNode()->getNumSuccessors(), 1U);
1284     if (ExitBB)
1285       ASSERT_EQ(CI->getNextNode()->getSuccessor(0), ExitBB);
1286     else
1287       ExitBB = CI->getNextNode()->getSuccessor(0);
1288     ASSERT_EQ(ExitBB->size(), 1U);
1289     if (!isa<ReturnInst>(ExitBB->front())) {
1290       ASSERT_TRUE(isa<BranchInst>(ExitBB->front()));
1291       ASSERT_EQ(cast<BranchInst>(ExitBB->front()).getNumSuccessors(), 1U);
1292       ASSERT_TRUE(isa<ReturnInst>(
1293           cast<BranchInst>(ExitBB->front()).getSuccessor(0)->front()));
1294     }
1295   }
1296 }
1297 
1298 TEST_F(OpenMPIRBuilderTest, ParallelForwardAsPointers) {
1299   OpenMPIRBuilder OMPBuilder(*M);
1300   OMPBuilder.Config.IsTargetDevice = false;
1301   OMPBuilder.initialize();
1302   F->setName("func");
1303   IRBuilder<> Builder(BB);
1304   OpenMPIRBuilder::LocationDescription Loc({Builder.saveIP(), DL});
1305   using InsertPointTy = OpenMPIRBuilder::InsertPointTy;
1306 
1307   Type *I32Ty = Type::getInt32Ty(M->getContext());
1308   Type *PtrTy = PointerType::get(M->getContext(), 0);
1309   Type *StructTy = StructType::get(I32Ty, PtrTy);
1310   Type *VoidTy = Type::getVoidTy(M->getContext());
1311   FunctionCallee RetI32Func = M->getOrInsertFunction("ret_i32", I32Ty);
1312   FunctionCallee TakeI32Func =
1313       M->getOrInsertFunction("take_i32", VoidTy, I32Ty);
1314   FunctionCallee RetI32PtrFunc = M->getOrInsertFunction("ret_i32ptr", PtrTy);
1315   FunctionCallee TakeI32PtrFunc =
1316       M->getOrInsertFunction("take_i32ptr", VoidTy, PtrTy);
1317   FunctionCallee RetStructFunc = M->getOrInsertFunction("ret_struct", StructTy);
1318   FunctionCallee TakeStructFunc =
1319       M->getOrInsertFunction("take_struct", VoidTy, StructTy);
1320   FunctionCallee RetStructPtrFunc =
1321       M->getOrInsertFunction("ret_structptr", PtrTy);
1322   FunctionCallee TakeStructPtrFunc =
1323       M->getOrInsertFunction("take_structPtr", VoidTy, PtrTy);
1324   Value *I32Val = Builder.CreateCall(RetI32Func);
1325   Value *I32PtrVal = Builder.CreateCall(RetI32PtrFunc);
1326   Value *StructVal = Builder.CreateCall(RetStructFunc);
1327   Value *StructPtrVal = Builder.CreateCall(RetStructPtrFunc);
1328 
1329   Instruction *Internal;
1330   auto BodyGenCB = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP) {
1331     IRBuilder<>::InsertPointGuard Guard(Builder);
1332     Builder.restoreIP(CodeGenIP);
1333     Internal = Builder.CreateCall(TakeI32Func, I32Val);
1334     Builder.CreateCall(TakeI32PtrFunc, I32PtrVal);
1335     Builder.CreateCall(TakeStructFunc, StructVal);
1336     Builder.CreateCall(TakeStructPtrFunc, StructPtrVal);
1337     return Error::success();
1338   };
1339   auto PrivCB = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP, Value &,
1340                     Value &Inner, Value *&ReplacementValue) {
1341     ReplacementValue = &Inner;
1342     return CodeGenIP;
1343   };
1344   auto FiniCB = [](InsertPointTy) { return Error::success(); };
1345 
1346   IRBuilder<>::InsertPoint AllocaIP(&F->getEntryBlock(),
1347                                     F->getEntryBlock().getFirstInsertionPt());
1348   OpenMPIRBuilder::InsertPointOrErrorTy AfterIP =
1349       OMPBuilder.createParallel(Loc, AllocaIP, BodyGenCB, PrivCB, FiniCB,
1350                                 nullptr, nullptr, OMP_PROC_BIND_default, false);
1351   assert(AfterIP && "unexpected error");
1352   Builder.restoreIP(*AfterIP);
1353   Builder.CreateRetVoid();
1354 
1355   OMPBuilder.finalize();
1356 
1357   EXPECT_FALSE(verifyModule(*M, &errs()));
1358   Function *OutlinedFn = Internal->getFunction();
1359 
1360   Type *Arg2Type = OutlinedFn->getArg(2)->getType();
1361   EXPECT_TRUE(Arg2Type->isPointerTy());
1362 }
1363 
1364 TEST_F(OpenMPIRBuilderTest, CanonicalLoopSimple) {
1365   using InsertPointTy = OpenMPIRBuilder::InsertPointTy;
1366   OpenMPIRBuilder OMPBuilder(*M);
1367   OMPBuilder.initialize();
1368   IRBuilder<> Builder(BB);
1369   OpenMPIRBuilder::LocationDescription Loc({Builder.saveIP(), DL});
1370   Value *TripCount = F->getArg(0);
1371 
1372   unsigned NumBodiesGenerated = 0;
1373   auto LoopBodyGenCB = [&](InsertPointTy CodeGenIP, llvm::Value *LC) {
1374     NumBodiesGenerated += 1;
1375 
1376     Builder.restoreIP(CodeGenIP);
1377 
1378     Value *Cmp = Builder.CreateICmpEQ(LC, TripCount);
1379     Instruction *ThenTerm, *ElseTerm;
1380     SplitBlockAndInsertIfThenElse(Cmp, CodeGenIP.getBlock()->getTerminator(),
1381                                   &ThenTerm, &ElseTerm);
1382     return Error::success();
1383   };
1384 
1385   Expected<CanonicalLoopInfo *> LoopResult =
1386       OMPBuilder.createCanonicalLoop(Loc, LoopBodyGenCB, TripCount);
1387   assert(LoopResult && "unexpected error");
1388   CanonicalLoopInfo *Loop = *LoopResult;
1389 
1390   Builder.restoreIP(Loop->getAfterIP());
1391   ReturnInst *RetInst = Builder.CreateRetVoid();
1392   OMPBuilder.finalize();
1393 
1394   Loop->assertOK();
1395   EXPECT_FALSE(verifyModule(*M, &errs()));
1396 
1397   EXPECT_EQ(NumBodiesGenerated, 1U);
1398 
1399   // Verify control flow structure (in addition to Loop->assertOK()).
1400   EXPECT_EQ(Loop->getPreheader()->getSinglePredecessor(), &F->getEntryBlock());
1401   EXPECT_EQ(Loop->getAfter(), Builder.GetInsertBlock());
1402 
1403   Instruction *IndVar = Loop->getIndVar();
1404   EXPECT_TRUE(isa<PHINode>(IndVar));
1405   EXPECT_EQ(IndVar->getType(), TripCount->getType());
1406   EXPECT_EQ(IndVar->getParent(), Loop->getHeader());
1407 
1408   EXPECT_EQ(Loop->getTripCount(), TripCount);
1409 
1410   BasicBlock *Body = Loop->getBody();
1411   Instruction *CmpInst = &Body->front();
1412   EXPECT_TRUE(isa<ICmpInst>(CmpInst));
1413   EXPECT_EQ(CmpInst->getOperand(0), IndVar);
1414 
1415   BasicBlock *LatchPred = Loop->getLatch()->getSinglePredecessor();
1416   EXPECT_TRUE(llvm::all_of(successors(Body), [=](BasicBlock *SuccBB) {
1417     return SuccBB->getSingleSuccessor() == LatchPred;
1418   }));
1419 
1420   EXPECT_EQ(&Loop->getAfter()->front(), RetInst);
1421 }
1422 
// Checks the trip count computed by createCanonicalLoop for constant
// (start, stop, step) triples, covering unsigned and signed comparison,
// exclusive and inclusive stop, and values at the 16-bit boundaries.
TEST_F(OpenMPIRBuilderTest, CanonicalLoopBounds) {
  using InsertPointTy = OpenMPIRBuilder::InsertPointTy;
  OpenMPIRBuilder OMPBuilder(*M);
  OMPBuilder.initialize();
  IRBuilder<> Builder(BB);

  // Check the trip count is computed correctly. We generate the canonical loop
  // but rely on the IRBuilder's constant folder to compute the final result
  // since all inputs are constant. To verify overflow situations, limit the
  // trip count / loop counter widths to 16 bits.
  //
  // Each call emits a new loop at the builder's current position and then
  // moves the builder to that loop's after-block, so the loops are chained
  // one after another in F. Returns the folded trip count zero-extended, so a
  // 16-bit 0xFFFF is reported as 65535 rather than -1.
  auto EvalTripCount = [&](int64_t Start, int64_t Stop, int64_t Step,
                           bool IsSigned, bool InclusiveStop) -> int64_t {
    OpenMPIRBuilder::LocationDescription Loc({Builder.saveIP(), DL});
    Type *LCTy = Type::getInt16Ty(Ctx);
    // NOTE(review): negative inputs and values above 0x7FFF rely on
    // ConstantInt::get truncating the 64-bit value to i16 — confirm if
    // ConstantInt::get's truncation behavior changes.
    Value *StartVal = ConstantInt::get(LCTy, Start);
    Value *StopVal = ConstantInt::get(LCTy, Stop);
    Value *StepVal = ConstantInt::get(LCTy, Step);
    // Empty body: only the trip-count computation matters here.
    auto LoopBodyGenCB = [&](InsertPointTy CodeGenIP, llvm::Value *LC) {
      return Error::success();
    };
    Expected<CanonicalLoopInfo *> LoopResult =
        OMPBuilder.createCanonicalLoop(Loc, LoopBodyGenCB, StartVal, StopVal,
                                       StepVal, IsSigned, InclusiveStop);
    assert(LoopResult && "unexpected error");
    CanonicalLoopInfo *Loop = *LoopResult;
    Loop->assertOK();
    Builder.restoreIP(Loop->getAfterIP());
    // With constant inputs the trip count is expected to be folded to a
    // ConstantInt; cast<> asserts if it was not.
    Value *TripCount = Loop->getTripCount();
    return cast<ConstantInt>(TripCount)->getValue().getZExtValue();
  };

  // Unsigned comparison, exclusive stop.
  EXPECT_EQ(EvalTripCount(0, 0, 1, false, false), 0);
  EXPECT_EQ(EvalTripCount(0, 1, 2, false, false), 1);
  EXPECT_EQ(EvalTripCount(0, 42, 1, false, false), 42);
  EXPECT_EQ(EvalTripCount(0, 42, 2, false, false), 21);
  EXPECT_EQ(EvalTripCount(21, 42, 1, false, false), 21);
  EXPECT_EQ(EvalTripCount(0, 5, 5, false, false), 1);
  EXPECT_EQ(EvalTripCount(0, 9, 5, false, false), 2);
  EXPECT_EQ(EvalTripCount(0, 11, 5, false, false), 3);
  EXPECT_EQ(EvalTripCount(0, 0xFFFF, 1, false, false), 0xFFFF);
  EXPECT_EQ(EvalTripCount(0xFFFF, 0, 1, false, false), 0);
  EXPECT_EQ(EvalTripCount(0xFFFE, 0xFFFF, 1, false, false), 1);
  EXPECT_EQ(EvalTripCount(0, 0xFFFF, 0x100, false, false), 0x100);
  EXPECT_EQ(EvalTripCount(0, 0xFFFF, 0xFFFF, false, false), 1);

  // Unsigned: a partial last step (exclusive stop), then inclusive stop.
  EXPECT_EQ(EvalTripCount(0, 6, 5, false, false), 2);
  EXPECT_EQ(EvalTripCount(0, 0xFFFF, 0xFFFE, false, false), 2);
  EXPECT_EQ(EvalTripCount(0, 0, 1, false, true), 1);
  EXPECT_EQ(EvalTripCount(0, 0, 0xFFFF, false, true), 1);
  EXPECT_EQ(EvalTripCount(0, 0xFFFE, 1, false, true), 0xFFFF);
  EXPECT_EQ(EvalTripCount(0, 0xFFFE, 2, false, true), 0x8000);

  // Signed comparison, including negative steps and negative bounds.
  EXPECT_EQ(EvalTripCount(0, 0, -1, true, false), 0);
  EXPECT_EQ(EvalTripCount(0, 1, -1, true, true), 0);
  EXPECT_EQ(EvalTripCount(20, 5, -5, true, false), 3);
  EXPECT_EQ(EvalTripCount(20, 5, -5, true, true), 4);
  EXPECT_EQ(EvalTripCount(-4, -2, 2, true, false), 1);
  EXPECT_EQ(EvalTripCount(-4, -3, 2, true, false), 1);
  EXPECT_EQ(EvalTripCount(-4, -2, 2, true, true), 2);

  // Signed comparison with bounds/steps at the 16-bit extremes.
  EXPECT_EQ(EvalTripCount(INT16_MIN, 0, 1, true, false), 0x8000);
  EXPECT_EQ(EvalTripCount(INT16_MIN, 0, 1, true, true), 0x8001);
  EXPECT_EQ(EvalTripCount(INT16_MIN, 0x7FFF, 1, true, false), 0xFFFF);
  EXPECT_EQ(EvalTripCount(INT16_MIN + 1, 0x7FFF, 1, true, true), 0xFFFF);
  EXPECT_EQ(EvalTripCount(INT16_MIN, 0, 0x7FFF, true, false), 2);
  EXPECT_EQ(EvalTripCount(0x7FFF, 0, -1, true, false), 0x7FFF);
  EXPECT_EQ(EvalTripCount(0, INT16_MIN, -1, true, false), 0x8000);
  EXPECT_EQ(EvalTripCount(0, INT16_MIN, -16, true, false), 0x800);
  EXPECT_EQ(EvalTripCount(0x7FFF, INT16_MIN, -1, true, false), 0xFFFF);
  EXPECT_EQ(EvalTripCount(0x7FFF, 1, INT16_MIN, true, false), 1);
  EXPECT_EQ(EvalTripCount(0x7FFF, -1, INT16_MIN, true, true), 2);

  // Finalize the function and verify it.
  Builder.CreateRetVoid();
  OMPBuilder.finalize();
  EXPECT_FALSE(verifyModule(*M, &errs()));
}
1500 
1501 TEST_F(OpenMPIRBuilderTest, CollapseNestedLoops) {
1502   using InsertPointTy = OpenMPIRBuilder::InsertPointTy;
1503   OpenMPIRBuilder OMPBuilder(*M);
1504   OMPBuilder.initialize();
1505   F->setName("func");
1506 
1507   IRBuilder<> Builder(BB);
1508 
1509   Type *LCTy = F->getArg(0)->getType();
1510   Constant *One = ConstantInt::get(LCTy, 1);
1511   Constant *Two = ConstantInt::get(LCTy, 2);
1512   Value *OuterTripCount =
1513       Builder.CreateAdd(F->getArg(0), Two, "tripcount.outer");
1514   Value *InnerTripCount =
1515       Builder.CreateAdd(F->getArg(0), One, "tripcount.inner");
1516 
1517   // Fix an insertion point for ComputeIP.
1518   BasicBlock *LoopNextEnter =
1519       BasicBlock::Create(M->getContext(), "loopnest.enter", F,
1520                          Builder.GetInsertBlock()->getNextNode());
1521   BranchInst *EnterBr = Builder.CreateBr(LoopNextEnter);
1522   InsertPointTy ComputeIP{EnterBr->getParent(), EnterBr->getIterator()};
1523 
1524   Builder.SetInsertPoint(LoopNextEnter);
1525   OpenMPIRBuilder::LocationDescription OuterLoc(Builder.saveIP(), DL);
1526 
1527   CanonicalLoopInfo *InnerLoop = nullptr;
1528   CallInst *InbetweenLead = nullptr;
1529   CallInst *InbetweenTrail = nullptr;
1530   CallInst *Call = nullptr;
1531   auto OuterLoopBodyGenCB = [&](InsertPointTy OuterCodeGenIP, Value *OuterLC) {
1532     Builder.restoreIP(OuterCodeGenIP);
1533     InbetweenLead =
1534         createPrintfCall(Builder, "In-between lead i=%d\\n", {OuterLC});
1535 
1536     auto InnerLoopBodyGenCB = [&](InsertPointTy InnerCodeGenIP,
1537                                   Value *InnerLC) {
1538       Builder.restoreIP(InnerCodeGenIP);
1539       Call = createPrintfCall(Builder, "body i=%d j=%d\\n", {OuterLC, InnerLC});
1540       return Error::success();
1541     };
1542     Expected<CanonicalLoopInfo *> LoopResult = OMPBuilder.createCanonicalLoop(
1543         Builder.saveIP(), InnerLoopBodyGenCB, InnerTripCount, "inner");
1544     assert(LoopResult && "unexpected error");
1545     InnerLoop = *LoopResult;
1546 
1547     Builder.restoreIP(InnerLoop->getAfterIP());
1548     InbetweenTrail =
1549         createPrintfCall(Builder, "In-between trail i=%d\\n", {OuterLC});
1550     return Error::success();
1551   };
1552   Expected<CanonicalLoopInfo *> LoopResult = OMPBuilder.createCanonicalLoop(
1553       OuterLoc, OuterLoopBodyGenCB, OuterTripCount, "outer");
1554   assert(LoopResult && "unexpected error");
1555   CanonicalLoopInfo *OuterLoop = *LoopResult;
1556 
1557   // Finish the function.
1558   Builder.restoreIP(OuterLoop->getAfterIP());
1559   Builder.CreateRetVoid();
1560 
1561   CanonicalLoopInfo *Collapsed =
1562       OMPBuilder.collapseLoops(DL, {OuterLoop, InnerLoop}, ComputeIP);
1563 
1564   OMPBuilder.finalize();
1565   EXPECT_FALSE(verifyModule(*M, &errs()));
1566 
1567   // Verify control flow and BB order.
1568   BasicBlock *RefOrder[] = {
1569       Collapsed->getPreheader(),   Collapsed->getHeader(),
1570       Collapsed->getCond(),        Collapsed->getBody(),
1571       InbetweenLead->getParent(),  Call->getParent(),
1572       InbetweenTrail->getParent(), Collapsed->getLatch(),
1573       Collapsed->getExit(),        Collapsed->getAfter(),
1574   };
1575   EXPECT_TRUE(verifyDFSOrder(F, RefOrder));
1576   EXPECT_TRUE(verifyListOrder(F, RefOrder));
1577 
1578   // Verify the total trip count.
1579   auto *TripCount = cast<MulOperator>(Collapsed->getTripCount());
1580   EXPECT_EQ(TripCount->getOperand(0), OuterTripCount);
1581   EXPECT_EQ(TripCount->getOperand(1), InnerTripCount);
1582 
1583   // Verify the changed indvar.
1584   auto *OuterIV = cast<BinaryOperator>(Call->getOperand(1));
1585   EXPECT_EQ(OuterIV->getOpcode(), Instruction::UDiv);
1586   EXPECT_EQ(OuterIV->getParent(), Collapsed->getBody());
1587   EXPECT_EQ(OuterIV->getOperand(1), InnerTripCount);
1588   EXPECT_EQ(OuterIV->getOperand(0), Collapsed->getIndVar());
1589 
1590   auto *InnerIV = cast<BinaryOperator>(Call->getOperand(2));
1591   EXPECT_EQ(InnerIV->getOpcode(), Instruction::URem);
1592   EXPECT_EQ(InnerIV->getParent(), Collapsed->getBody());
1593   EXPECT_EQ(InnerIV->getOperand(0), Collapsed->getIndVar());
1594   EXPECT_EQ(InnerIV->getOperand(1), InnerTripCount);
1595 
1596   EXPECT_EQ(InbetweenLead->getOperand(1), OuterIV);
1597   EXPECT_EQ(InbetweenTrail->getOperand(1), OuterIV);
1598 }
1599 
1600 TEST_F(OpenMPIRBuilderTest, TileSingleLoop) {
1601   OpenMPIRBuilder OMPBuilder(*M);
1602   CallInst *Call;
1603   BasicBlock *BodyCode;
1604   CanonicalLoopInfo *Loop =
1605       buildSingleLoopFunction(DL, OMPBuilder, 32, &Call, &BodyCode);
1606 
1607   Instruction *OrigIndVar = Loop->getIndVar();
1608   EXPECT_EQ(Call->getOperand(1), OrigIndVar);
1609 
1610   // Tile the loop.
1611   Constant *TileSize = ConstantInt::get(Loop->getIndVarType(), APInt(32, 7));
1612   std::vector<CanonicalLoopInfo *> GenLoops =
1613       OMPBuilder.tileLoops(DL, {Loop}, {TileSize});
1614 
1615   OMPBuilder.finalize();
1616   EXPECT_FALSE(verifyModule(*M, &errs()));
1617 
1618   EXPECT_EQ(GenLoops.size(), 2u);
1619   CanonicalLoopInfo *Floor = GenLoops[0];
1620   CanonicalLoopInfo *Tile = GenLoops[1];
1621 
1622   BasicBlock *RefOrder[] = {
1623       Floor->getPreheader(), Floor->getHeader(),   Floor->getCond(),
1624       Floor->getBody(),      Tile->getPreheader(), Tile->getHeader(),
1625       Tile->getCond(),       Tile->getBody(),      BodyCode,
1626       Tile->getLatch(),      Tile->getExit(),      Tile->getAfter(),
1627       Floor->getLatch(),     Floor->getExit(),     Floor->getAfter(),
1628   };
1629   EXPECT_TRUE(verifyDFSOrder(F, RefOrder));
1630   EXPECT_TRUE(verifyListOrder(F, RefOrder));
1631 
1632   // Check the induction variable.
1633   EXPECT_EQ(Call->getParent(), BodyCode);
1634   auto *Shift = cast<AddOperator>(Call->getOperand(1));
1635   EXPECT_EQ(cast<Instruction>(Shift)->getParent(), Tile->getBody());
1636   EXPECT_EQ(Shift->getOperand(1), Tile->getIndVar());
1637   auto *Scale = cast<MulOperator>(Shift->getOperand(0));
1638   EXPECT_EQ(cast<Instruction>(Scale)->getParent(), Tile->getBody());
1639   EXPECT_EQ(Scale->getOperand(0), TileSize);
1640   EXPECT_EQ(Scale->getOperand(1), Floor->getIndVar());
1641 }
1642 
1643 TEST_F(OpenMPIRBuilderTest, TileNestedLoops) {
1644   using InsertPointTy = OpenMPIRBuilder::InsertPointTy;
1645   OpenMPIRBuilder OMPBuilder(*M);
1646   OMPBuilder.initialize();
1647   F->setName("func");
1648 
1649   IRBuilder<> Builder(BB);
1650   OpenMPIRBuilder::LocationDescription Loc({Builder.saveIP(), DL});
1651   Value *TripCount = F->getArg(0);
1652   Type *LCTy = TripCount->getType();
1653 
1654   BasicBlock *BodyCode = nullptr;
1655   CanonicalLoopInfo *InnerLoop = nullptr;
1656   auto OuterLoopBodyGenCB = [&](InsertPointTy OuterCodeGenIP,
1657                                 llvm::Value *OuterLC) {
1658     auto InnerLoopBodyGenCB = [&](InsertPointTy InnerCodeGenIP,
1659                                   llvm::Value *InnerLC) {
1660       Builder.restoreIP(InnerCodeGenIP);
1661       BodyCode = Builder.GetInsertBlock();
1662 
1663       // Add something that consumes the induction variables to the body.
1664       createPrintfCall(Builder, "i=%d j=%d\\n", {OuterLC, InnerLC});
1665       return Error::success();
1666     };
1667     Expected<CanonicalLoopInfo *> LoopResult = OMPBuilder.createCanonicalLoop(
1668         OuterCodeGenIP, InnerLoopBodyGenCB, TripCount, "inner");
1669     assert(LoopResult && "unexpected error");
1670     InnerLoop = *LoopResult;
1671     return Error::success();
1672   };
1673   Expected<CanonicalLoopInfo *> LoopResult = OMPBuilder.createCanonicalLoop(
1674       Loc, OuterLoopBodyGenCB, TripCount, "outer");
1675   assert(LoopResult && "unexpected error");
1676   CanonicalLoopInfo *OuterLoop = *LoopResult;
1677 
1678   // Finalize the function.
1679   Builder.restoreIP(OuterLoop->getAfterIP());
1680   Builder.CreateRetVoid();
1681 
1682   // Tile to loop nest.
1683   Constant *OuterTileSize = ConstantInt::get(LCTy, APInt(32, 11));
1684   Constant *InnerTileSize = ConstantInt::get(LCTy, APInt(32, 7));
1685   std::vector<CanonicalLoopInfo *> GenLoops = OMPBuilder.tileLoops(
1686       DL, {OuterLoop, InnerLoop}, {OuterTileSize, InnerTileSize});
1687 
1688   OMPBuilder.finalize();
1689   EXPECT_FALSE(verifyModule(*M, &errs()));
1690 
1691   EXPECT_EQ(GenLoops.size(), 4u);
1692   CanonicalLoopInfo *Floor1 = GenLoops[0];
1693   CanonicalLoopInfo *Floor2 = GenLoops[1];
1694   CanonicalLoopInfo *Tile1 = GenLoops[2];
1695   CanonicalLoopInfo *Tile2 = GenLoops[3];
1696 
1697   BasicBlock *RefOrder[] = {
1698       Floor1->getPreheader(),
1699       Floor1->getHeader(),
1700       Floor1->getCond(),
1701       Floor1->getBody(),
1702       Floor2->getPreheader(),
1703       Floor2->getHeader(),
1704       Floor2->getCond(),
1705       Floor2->getBody(),
1706       Tile1->getPreheader(),
1707       Tile1->getHeader(),
1708       Tile1->getCond(),
1709       Tile1->getBody(),
1710       Tile2->getPreheader(),
1711       Tile2->getHeader(),
1712       Tile2->getCond(),
1713       Tile2->getBody(),
1714       BodyCode,
1715       Tile2->getLatch(),
1716       Tile2->getExit(),
1717       Tile2->getAfter(),
1718       Tile1->getLatch(),
1719       Tile1->getExit(),
1720       Tile1->getAfter(),
1721       Floor2->getLatch(),
1722       Floor2->getExit(),
1723       Floor2->getAfter(),
1724       Floor1->getLatch(),
1725       Floor1->getExit(),
1726       Floor1->getAfter(),
1727   };
1728   EXPECT_TRUE(verifyDFSOrder(F, RefOrder));
1729   EXPECT_TRUE(verifyListOrder(F, RefOrder));
1730 }
1731 
1732 TEST_F(OpenMPIRBuilderTest, TileNestedLoopsWithBounds) {
1733   using InsertPointTy = OpenMPIRBuilder::InsertPointTy;
1734   OpenMPIRBuilder OMPBuilder(*M);
1735   OMPBuilder.initialize();
1736   F->setName("func");
1737 
1738   IRBuilder<> Builder(BB);
1739   Value *TripCount = F->getArg(0);
1740   Type *LCTy = TripCount->getType();
1741 
1742   Value *OuterStartVal = ConstantInt::get(LCTy, 2);
1743   Value *OuterStopVal = TripCount;
1744   Value *OuterStep = ConstantInt::get(LCTy, 5);
1745   Value *InnerStartVal = ConstantInt::get(LCTy, 13);
1746   Value *InnerStopVal = TripCount;
1747   Value *InnerStep = ConstantInt::get(LCTy, 3);
1748 
1749   // Fix an insertion point for ComputeIP.
1750   BasicBlock *LoopNextEnter =
1751       BasicBlock::Create(M->getContext(), "loopnest.enter", F,
1752                          Builder.GetInsertBlock()->getNextNode());
1753   BranchInst *EnterBr = Builder.CreateBr(LoopNextEnter);
1754   InsertPointTy ComputeIP{EnterBr->getParent(), EnterBr->getIterator()};
1755 
1756   InsertPointTy LoopIP{LoopNextEnter, LoopNextEnter->begin()};
1757   OpenMPIRBuilder::LocationDescription Loc({LoopIP, DL});
1758 
1759   BasicBlock *BodyCode = nullptr;
1760   CanonicalLoopInfo *InnerLoop = nullptr;
1761   CallInst *Call = nullptr;
1762   auto OuterLoopBodyGenCB = [&](InsertPointTy OuterCodeGenIP,
1763                                 llvm::Value *OuterLC) {
1764     auto InnerLoopBodyGenCB = [&](InsertPointTy InnerCodeGenIP,
1765                                   llvm::Value *InnerLC) {
1766       Builder.restoreIP(InnerCodeGenIP);
1767       BodyCode = Builder.GetInsertBlock();
1768 
1769       // Add something that consumes the induction variable to the body.
1770       Call = createPrintfCall(Builder, "i=%d j=%d\\n", {OuterLC, InnerLC});
1771       return Error::success();
1772     };
1773     Expected<CanonicalLoopInfo *> LoopResult = OMPBuilder.createCanonicalLoop(
1774         OuterCodeGenIP, InnerLoopBodyGenCB, InnerStartVal, InnerStopVal,
1775         InnerStep, false, false, ComputeIP, "inner");
1776     assert(LoopResult && "unexpected error");
1777     InnerLoop = *LoopResult;
1778     return Error::success();
1779   };
1780   Expected<CanonicalLoopInfo *> LoopResult = OMPBuilder.createCanonicalLoop(
1781       Loc, OuterLoopBodyGenCB, OuterStartVal, OuterStopVal, OuterStep, false,
1782       false, ComputeIP, "outer");
1783   assert(LoopResult && "unexpected error");
1784   CanonicalLoopInfo *OuterLoop = *LoopResult;
1785 
1786   // Finalize the function
1787   Builder.restoreIP(OuterLoop->getAfterIP());
1788   Builder.CreateRetVoid();
1789 
1790   // Tile the loop nest.
1791   Constant *TileSize0 = ConstantInt::get(LCTy, APInt(32, 11));
1792   Constant *TileSize1 = ConstantInt::get(LCTy, APInt(32, 7));
1793   std::vector<CanonicalLoopInfo *> GenLoops =
1794       OMPBuilder.tileLoops(DL, {OuterLoop, InnerLoop}, {TileSize0, TileSize1});
1795 
1796   OMPBuilder.finalize();
1797   EXPECT_FALSE(verifyModule(*M, &errs()));
1798 
1799   EXPECT_EQ(GenLoops.size(), 4u);
1800   CanonicalLoopInfo *Floor0 = GenLoops[0];
1801   CanonicalLoopInfo *Floor1 = GenLoops[1];
1802   CanonicalLoopInfo *Tile0 = GenLoops[2];
1803   CanonicalLoopInfo *Tile1 = GenLoops[3];
1804 
1805   BasicBlock *RefOrder[] = {
1806       Floor0->getPreheader(),
1807       Floor0->getHeader(),
1808       Floor0->getCond(),
1809       Floor0->getBody(),
1810       Floor1->getPreheader(),
1811       Floor1->getHeader(),
1812       Floor1->getCond(),
1813       Floor1->getBody(),
1814       Tile0->getPreheader(),
1815       Tile0->getHeader(),
1816       Tile0->getCond(),
1817       Tile0->getBody(),
1818       Tile1->getPreheader(),
1819       Tile1->getHeader(),
1820       Tile1->getCond(),
1821       Tile1->getBody(),
1822       BodyCode,
1823       Tile1->getLatch(),
1824       Tile1->getExit(),
1825       Tile1->getAfter(),
1826       Tile0->getLatch(),
1827       Tile0->getExit(),
1828       Tile0->getAfter(),
1829       Floor1->getLatch(),
1830       Floor1->getExit(),
1831       Floor1->getAfter(),
1832       Floor0->getLatch(),
1833       Floor0->getExit(),
1834       Floor0->getAfter(),
1835   };
1836   EXPECT_TRUE(verifyDFSOrder(F, RefOrder));
1837   EXPECT_TRUE(verifyListOrder(F, RefOrder));
1838 
1839   EXPECT_EQ(Call->getParent(), BodyCode);
1840 
1841   auto *RangeShift0 = cast<AddOperator>(Call->getOperand(1));
1842   EXPECT_EQ(RangeShift0->getOperand(1), OuterStartVal);
1843   auto *RangeScale0 = cast<MulOperator>(RangeShift0->getOperand(0));
1844   EXPECT_EQ(RangeScale0->getOperand(1), OuterStep);
1845   auto *TileShift0 = cast<AddOperator>(RangeScale0->getOperand(0));
1846   EXPECT_EQ(cast<Instruction>(TileShift0)->getParent(), Tile1->getBody());
1847   EXPECT_EQ(TileShift0->getOperand(1), Tile0->getIndVar());
1848   auto *TileScale0 = cast<MulOperator>(TileShift0->getOperand(0));
1849   EXPECT_EQ(cast<Instruction>(TileScale0)->getParent(), Tile1->getBody());
1850   EXPECT_EQ(TileScale0->getOperand(0), TileSize0);
1851   EXPECT_EQ(TileScale0->getOperand(1), Floor0->getIndVar());
1852 
1853   auto *RangeShift1 = cast<AddOperator>(Call->getOperand(2));
1854   EXPECT_EQ(cast<Instruction>(RangeShift1)->getParent(), BodyCode);
1855   EXPECT_EQ(RangeShift1->getOperand(1), InnerStartVal);
1856   auto *RangeScale1 = cast<MulOperator>(RangeShift1->getOperand(0));
1857   EXPECT_EQ(cast<Instruction>(RangeScale1)->getParent(), BodyCode);
1858   EXPECT_EQ(RangeScale1->getOperand(1), InnerStep);
1859   auto *TileShift1 = cast<AddOperator>(RangeScale1->getOperand(0));
1860   EXPECT_EQ(cast<Instruction>(TileShift1)->getParent(), Tile1->getBody());
1861   EXPECT_EQ(TileShift1->getOperand(1), Tile1->getIndVar());
1862   auto *TileScale1 = cast<MulOperator>(TileShift1->getOperand(0));
1863   EXPECT_EQ(cast<Instruction>(TileScale1)->getParent(), Tile1->getBody());
1864   EXPECT_EQ(TileScale1->getOperand(0), TileSize1);
1865   EXPECT_EQ(TileScale1->getOperand(1), Floor1->getIndVar());
1866 }
1867 
1868 TEST_F(OpenMPIRBuilderTest, TileSingleLoopCounts) {
1869   using InsertPointTy = OpenMPIRBuilder::InsertPointTy;
1870   OpenMPIRBuilder OMPBuilder(*M);
1871   OMPBuilder.initialize();
1872   IRBuilder<> Builder(BB);
1873 
1874   // Create a loop, tile it, and extract its trip count. All input values are
1875   // constant and IRBuilder evaluates all-constant arithmetic inplace, such that
1876   // the floor trip count itself will be a ConstantInt. Unfortunately we cannot
1877   // do the same for the tile loop.
1878   auto GetFloorCount = [&](int64_t Start, int64_t Stop, int64_t Step,
1879                            bool IsSigned, bool InclusiveStop,
1880                            int64_t TileSize) -> uint64_t {
1881     OpenMPIRBuilder::LocationDescription Loc(Builder.saveIP(), DL);
1882     Type *LCTy = Type::getInt16Ty(Ctx);
1883     Value *StartVal = ConstantInt::get(LCTy, Start);
1884     Value *StopVal = ConstantInt::get(LCTy, Stop);
1885     Value *StepVal = ConstantInt::get(LCTy, Step);
1886 
1887     // Generate a loop.
1888     auto LoopBodyGenCB = [&](InsertPointTy CodeGenIP, llvm::Value *LC) {
1889       return Error::success();
1890     };
1891     Expected<CanonicalLoopInfo *> LoopResult =
1892         OMPBuilder.createCanonicalLoop(Loc, LoopBodyGenCB, StartVal, StopVal,
1893                                        StepVal, IsSigned, InclusiveStop);
1894     assert(LoopResult && "unexpected error");
1895     CanonicalLoopInfo *Loop = *LoopResult;
1896     InsertPointTy AfterIP = Loop->getAfterIP();
1897 
1898     // Tile the loop.
1899     Value *TileSizeVal = ConstantInt::get(LCTy, TileSize);
1900     std::vector<CanonicalLoopInfo *> GenLoops =
1901         OMPBuilder.tileLoops(Loc.DL, {Loop}, {TileSizeVal});
1902 
1903     // Set the insertion pointer to after loop, where the next loop will be
1904     // emitted.
1905     Builder.restoreIP(AfterIP);
1906 
1907     // Extract the trip count.
1908     CanonicalLoopInfo *FloorLoop = GenLoops[0];
1909     Value *FloorTripCount = FloorLoop->getTripCount();
1910     return cast<ConstantInt>(FloorTripCount)->getValue().getZExtValue();
1911   };
1912 
1913   // Empty iteration domain.
1914   EXPECT_EQ(GetFloorCount(0, 0, 1, false, false, 7), 0u);
1915   EXPECT_EQ(GetFloorCount(0, -1, 1, false, true, 7), 0u);
1916   EXPECT_EQ(GetFloorCount(-1, -1, -1, true, false, 7), 0u);
1917   EXPECT_EQ(GetFloorCount(-1, 0, -1, true, true, 7), 0u);
1918   EXPECT_EQ(GetFloorCount(-1, -1, 3, true, false, 7), 0u);
1919 
1920   // Only complete tiles.
1921   EXPECT_EQ(GetFloorCount(0, 14, 1, false, false, 7), 2u);
1922   EXPECT_EQ(GetFloorCount(0, 14, 1, false, false, 7), 2u);
1923   EXPECT_EQ(GetFloorCount(1, 15, 1, false, false, 7), 2u);
1924   EXPECT_EQ(GetFloorCount(0, -14, -1, true, false, 7), 2u);
1925   EXPECT_EQ(GetFloorCount(-1, -14, -1, true, true, 7), 2u);
1926   EXPECT_EQ(GetFloorCount(0, 3 * 7 * 2, 3, false, false, 7), 2u);
1927 
1928   // Only a partial tile.
1929   EXPECT_EQ(GetFloorCount(0, 1, 1, false, false, 7), 1u);
1930   EXPECT_EQ(GetFloorCount(0, 6, 1, false, false, 7), 1u);
1931   EXPECT_EQ(GetFloorCount(-1, 1, 3, true, false, 7), 1u);
1932   EXPECT_EQ(GetFloorCount(-1, -2, -1, true, false, 7), 1u);
1933   EXPECT_EQ(GetFloorCount(0, 2, 3, false, false, 7), 1u);
1934 
1935   // Complete and partial tiles.
1936   EXPECT_EQ(GetFloorCount(0, 13, 1, false, false, 7), 2u);
1937   EXPECT_EQ(GetFloorCount(0, 15, 1, false, false, 7), 3u);
1938   EXPECT_EQ(GetFloorCount(-1, -14, -1, true, false, 7), 2u);
1939   EXPECT_EQ(GetFloorCount(0, 3 * 7 * 5 - 1, 3, false, false, 7), 5u);
1940   EXPECT_EQ(GetFloorCount(-1, -3 * 7 * 5, -3, true, false, 7), 5u);
1941 
1942   // Close to 16-bit integer range.
1943   EXPECT_EQ(GetFloorCount(0, 0xFFFF, 1, false, false, 1), 0xFFFFu);
1944   EXPECT_EQ(GetFloorCount(0, 0xFFFF, 1, false, false, 7), 0xFFFFu / 7 + 1);
1945   EXPECT_EQ(GetFloorCount(0, 0xFFFE, 1, false, true, 7), 0xFFFFu / 7 + 1);
1946   EXPECT_EQ(GetFloorCount(-0x8000, 0x7FFF, 1, true, false, 7), 0xFFFFu / 7 + 1);
1947   EXPECT_EQ(GetFloorCount(-0x7FFF, 0x7FFF, 1, true, true, 7), 0xFFFFu / 7 + 1);
1948   EXPECT_EQ(GetFloorCount(0, 0xFFFE, 1, false, false, 0xFFFF), 1u);
1949   EXPECT_EQ(GetFloorCount(-0x8000, 0x7FFF, 1, true, false, 0xFFFF), 1u);
1950 
1951   // Finalize the function.
1952   Builder.CreateRetVoid();
1953   OMPBuilder.finalize();
1954 
1955   EXPECT_FALSE(verifyModule(*M, &errs()));
1956 }
1957 
1958 TEST_F(OpenMPIRBuilderTest, ApplySimd) {
1959   OpenMPIRBuilder OMPBuilder(*M);
1960   MapVector<Value *, Value *> AlignedVars;
1961   CanonicalLoopInfo *CLI = buildSingleLoopFunction(DL, OMPBuilder, 32);
1962 
1963   // Simd-ize the loop.
1964   OMPBuilder.applySimd(CLI, AlignedVars, /* IfCond */ nullptr,
1965                        OrderKind::OMP_ORDER_unknown,
1966                        /* Simdlen */ nullptr,
1967                        /* Safelen */ nullptr);
1968 
1969   OMPBuilder.finalize();
1970   EXPECT_FALSE(verifyModule(*M, &errs()));
1971 
1972   PassBuilder PB;
1973   FunctionAnalysisManager FAM;
1974   PB.registerFunctionAnalyses(FAM);
1975   LoopInfo &LI = FAM.getResult<LoopAnalysis>(*F);
1976 
1977   const std::vector<Loop *> &TopLvl = LI.getTopLevelLoops();
1978   EXPECT_EQ(TopLvl.size(), 1u);
1979 
1980   Loop *L = TopLvl.front();
1981   EXPECT_TRUE(findStringMetadataForLoop(L, "llvm.loop.parallel_accesses"));
1982   EXPECT_TRUE(getBooleanLoopAttribute(L, "llvm.loop.vectorize.enable"));
1983 
1984   // Check for llvm.access.group metadata attached to the printf
1985   // function in the loop body.
1986   BasicBlock *LoopBody = CLI->getBody();
1987   EXPECT_TRUE(any_of(*LoopBody, [](Instruction &I) {
1988     return I.getMetadata("llvm.access.group") != nullptr;
1989   }));
1990 }
1991 
1992 TEST_F(OpenMPIRBuilderTest, ApplySimdCustomAligned) {
1993   OpenMPIRBuilder OMPBuilder(*M);
1994   IRBuilder<> Builder(BB);
1995   const int AlignmentValue = 32;
1996   AllocaInst *Alloc1 =
1997       Builder.CreateAlloca(Builder.getPtrTy(), Builder.getInt64(1));
1998   LoadInst *Load1 = Builder.CreateLoad(Alloc1->getAllocatedType(), Alloc1);
1999   MapVector<Value *, Value *> AlignedVars;
2000   AlignedVars.insert({Load1, Builder.getInt64(AlignmentValue)});
2001 
2002   CanonicalLoopInfo *CLI = buildSingleLoopFunction(DL, OMPBuilder, 32);
2003 
2004   // Simd-ize the loop.
2005   OMPBuilder.applySimd(CLI, AlignedVars, /* IfCond */ nullptr,
2006                        OrderKind::OMP_ORDER_unknown,
2007                        /* Simdlen */ nullptr,
2008                        /* Safelen */ nullptr);
2009 
2010   OMPBuilder.finalize();
2011   EXPECT_FALSE(verifyModule(*M, &errs()));
2012 
2013   PassBuilder PB;
2014   FunctionAnalysisManager FAM;
2015   PB.registerFunctionAnalyses(FAM);
2016   LoopInfo &LI = FAM.getResult<LoopAnalysis>(*F);
2017 
2018   const std::vector<Loop *> &TopLvl = LI.getTopLevelLoops();
2019   EXPECT_EQ(TopLvl.size(), 1u);
2020 
2021   Loop *L = TopLvl.front();
2022   EXPECT_TRUE(findStringMetadataForLoop(L, "llvm.loop.parallel_accesses"));
2023   EXPECT_TRUE(getBooleanLoopAttribute(L, "llvm.loop.vectorize.enable"));
2024 
2025   // Check for llvm.access.group metadata attached to the printf
2026   // function in the loop body.
2027   BasicBlock *LoopBody = CLI->getBody();
2028   EXPECT_TRUE(any_of(*LoopBody, [](Instruction &I) {
2029     return I.getMetadata("llvm.access.group") != nullptr;
2030   }));
2031 
2032   // Check if number of assumption instructions is equal to number of aligned
2033   // variables
2034   BasicBlock *LoopPreheader = CLI->getPreheader();
2035   size_t NumAssummptionCallsInPreheader = count_if(
2036       *LoopPreheader, [](Instruction &I) { return isa<AssumeInst>(I); });
2037   EXPECT_EQ(NumAssummptionCallsInPreheader, AlignedVars.size());
2038 
2039   // Check if variables are correctly aligned
2040   for (Instruction &Instr : *LoopPreheader) {
2041     if (!isa<AssumeInst>(Instr))
2042       continue;
2043     AssumeInst *AssumeInstruction = cast<AssumeInst>(&Instr);
2044     if (AssumeInstruction->getNumTotalBundleOperands()) {
2045       auto Bundle = AssumeInstruction->getOperandBundleAt(0);
2046       if (Bundle.getTagName() == "align") {
2047         EXPECT_TRUE(isa<ConstantInt>(Bundle.Inputs[1]));
2048         auto ConstIntVal = dyn_cast<ConstantInt>(Bundle.Inputs[1]);
2049         EXPECT_EQ(ConstIntVal->getSExtValue(), AlignmentValue);
2050       }
2051     }
2052   }
2053 }
2054 TEST_F(OpenMPIRBuilderTest, ApplySimdlen) {
2055   OpenMPIRBuilder OMPBuilder(*M);
2056   MapVector<Value *, Value *> AlignedVars;
2057   CanonicalLoopInfo *CLI = buildSingleLoopFunction(DL, OMPBuilder, 32);
2058 
2059   // Simd-ize the loop.
2060   OMPBuilder.applySimd(CLI, AlignedVars,
2061                        /* IfCond */ nullptr, OrderKind::OMP_ORDER_unknown,
2062                        ConstantInt::get(Type::getInt32Ty(Ctx), 3),
2063                        /* Safelen */ nullptr);
2064 
2065   OMPBuilder.finalize();
2066   EXPECT_FALSE(verifyModule(*M, &errs()));
2067 
2068   PassBuilder PB;
2069   FunctionAnalysisManager FAM;
2070   PB.registerFunctionAnalyses(FAM);
2071   LoopInfo &LI = FAM.getResult<LoopAnalysis>(*F);
2072 
2073   const std::vector<Loop *> &TopLvl = LI.getTopLevelLoops();
2074   EXPECT_EQ(TopLvl.size(), 1u);
2075 
2076   Loop *L = TopLvl.front();
2077   EXPECT_TRUE(findStringMetadataForLoop(L, "llvm.loop.parallel_accesses"));
2078   EXPECT_TRUE(getBooleanLoopAttribute(L, "llvm.loop.vectorize.enable"));
2079   EXPECT_EQ(getIntLoopAttribute(L, "llvm.loop.vectorize.width"), 3);
2080 
2081   // Check for llvm.access.group metadata attached to the printf
2082   // function in the loop body.
2083   BasicBlock *LoopBody = CLI->getBody();
2084   EXPECT_TRUE(any_of(*LoopBody, [](Instruction &I) {
2085     return I.getMetadata("llvm.access.group") != nullptr;
2086   }));
2087 }
2088 
2089 TEST_F(OpenMPIRBuilderTest, ApplySafelenOrderConcurrent) {
2090   OpenMPIRBuilder OMPBuilder(*M);
2091   MapVector<Value *, Value *> AlignedVars;
2092 
2093   CanonicalLoopInfo *CLI = buildSingleLoopFunction(DL, OMPBuilder, 32);
2094 
2095   // Simd-ize the loop.
2096   OMPBuilder.applySimd(
2097       CLI, AlignedVars, /* IfCond */ nullptr, OrderKind::OMP_ORDER_concurrent,
2098       /* Simdlen */ nullptr, ConstantInt::get(Type::getInt32Ty(Ctx), 3));
2099 
2100   OMPBuilder.finalize();
2101   EXPECT_FALSE(verifyModule(*M, &errs()));
2102 
2103   PassBuilder PB;
2104   FunctionAnalysisManager FAM;
2105   PB.registerFunctionAnalyses(FAM);
2106   LoopInfo &LI = FAM.getResult<LoopAnalysis>(*F);
2107 
2108   const std::vector<Loop *> &TopLvl = LI.getTopLevelLoops();
2109   EXPECT_EQ(TopLvl.size(), 1u);
2110 
2111   Loop *L = TopLvl.front();
2112   // Parallel metadata shoudl be attached because of presence of
2113   // the order(concurrent) OpenMP clause
2114   EXPECT_TRUE(findStringMetadataForLoop(L, "llvm.loop.parallel_accesses"));
2115   EXPECT_TRUE(getBooleanLoopAttribute(L, "llvm.loop.vectorize.enable"));
2116   EXPECT_EQ(getIntLoopAttribute(L, "llvm.loop.vectorize.width"), 3);
2117 
2118   // Check for llvm.access.group metadata attached to the printf
2119   // function in the loop body.
2120   BasicBlock *LoopBody = CLI->getBody();
2121   EXPECT_TRUE(any_of(*LoopBody, [](Instruction &I) {
2122     return I.getMetadata("llvm.access.group") != nullptr;
2123   }));
2124 }
2125 
2126 TEST_F(OpenMPIRBuilderTest, ApplySafelen) {
2127   OpenMPIRBuilder OMPBuilder(*M);
2128   MapVector<Value *, Value *> AlignedVars;
2129 
2130   CanonicalLoopInfo *CLI = buildSingleLoopFunction(DL, OMPBuilder, 32);
2131 
2132   OMPBuilder.applySimd(
2133       CLI, AlignedVars, /* IfCond */ nullptr, OrderKind::OMP_ORDER_unknown,
2134       /* Simdlen */ nullptr, ConstantInt::get(Type::getInt32Ty(Ctx), 3));
2135 
2136   OMPBuilder.finalize();
2137   EXPECT_FALSE(verifyModule(*M, &errs()));
2138 
2139   PassBuilder PB;
2140   FunctionAnalysisManager FAM;
2141   PB.registerFunctionAnalyses(FAM);
2142   LoopInfo &LI = FAM.getResult<LoopAnalysis>(*F);
2143 
2144   const std::vector<Loop *> &TopLvl = LI.getTopLevelLoops();
2145   EXPECT_EQ(TopLvl.size(), 1u);
2146 
2147   Loop *L = TopLvl.front();
2148   EXPECT_FALSE(findStringMetadataForLoop(L, "llvm.loop.parallel_accesses"));
2149   EXPECT_TRUE(getBooleanLoopAttribute(L, "llvm.loop.vectorize.enable"));
2150   EXPECT_EQ(getIntLoopAttribute(L, "llvm.loop.vectorize.width"), 3);
2151 
2152   // Check for llvm.access.group metadata attached to the printf
2153   // function in the loop body.
2154   BasicBlock *LoopBody = CLI->getBody();
2155   EXPECT_FALSE(any_of(*LoopBody, [](Instruction &I) {
2156     return I.getMetadata("llvm.access.group") != nullptr;
2157   }));
2158 }
2159 
2160 TEST_F(OpenMPIRBuilderTest, ApplySimdlenSafelen) {
2161   OpenMPIRBuilder OMPBuilder(*M);
2162   MapVector<Value *, Value *> AlignedVars;
2163 
2164   CanonicalLoopInfo *CLI = buildSingleLoopFunction(DL, OMPBuilder, 32);
2165 
2166   OMPBuilder.applySimd(CLI, AlignedVars, /* IfCond */ nullptr,
2167                        OrderKind::OMP_ORDER_unknown,
2168                        ConstantInt::get(Type::getInt32Ty(Ctx), 2),
2169                        ConstantInt::get(Type::getInt32Ty(Ctx), 3));
2170 
2171   OMPBuilder.finalize();
2172   EXPECT_FALSE(verifyModule(*M, &errs()));
2173 
2174   PassBuilder PB;
2175   FunctionAnalysisManager FAM;
2176   PB.registerFunctionAnalyses(FAM);
2177   LoopInfo &LI = FAM.getResult<LoopAnalysis>(*F);
2178 
2179   const std::vector<Loop *> &TopLvl = LI.getTopLevelLoops();
2180   EXPECT_EQ(TopLvl.size(), 1u);
2181 
2182   Loop *L = TopLvl.front();
2183   EXPECT_FALSE(findStringMetadataForLoop(L, "llvm.loop.parallel_accesses"));
2184   EXPECT_TRUE(getBooleanLoopAttribute(L, "llvm.loop.vectorize.enable"));
2185   EXPECT_EQ(getIntLoopAttribute(L, "llvm.loop.vectorize.width"), 2);
2186 
2187   // Check for llvm.access.group metadata attached to the printf
2188   // function in the loop body.
2189   BasicBlock *LoopBody = CLI->getBody();
2190   EXPECT_FALSE(any_of(*LoopBody, [](Instruction &I) {
2191     return I.getMetadata("llvm.access.group") != nullptr;
2192   }));
2193 }
2194 
2195 TEST_F(OpenMPIRBuilderTest, ApplySimdIf) {
2196   OpenMPIRBuilder OMPBuilder(*M);
2197   IRBuilder<> Builder(BB);
2198   MapVector<Value *, Value *> AlignedVars;
2199   AllocaInst *Alloc1 = Builder.CreateAlloca(Builder.getInt32Ty());
2200   AllocaInst *Alloc2 = Builder.CreateAlloca(Builder.getInt32Ty());
2201 
2202   // Generation of if condition
2203   Builder.CreateStore(ConstantInt::get(Type::getInt32Ty(Ctx), 0U), Alloc1);
2204   Builder.CreateStore(ConstantInt::get(Type::getInt32Ty(Ctx), 1U), Alloc2);
2205   LoadInst *Load1 = Builder.CreateLoad(Alloc1->getAllocatedType(), Alloc1);
2206   LoadInst *Load2 = Builder.CreateLoad(Alloc2->getAllocatedType(), Alloc2);
2207 
2208   Value *IfCmp = Builder.CreateICmpNE(Load1, Load2);
2209 
2210   CanonicalLoopInfo *CLI = buildSingleLoopFunction(DL, OMPBuilder, 32);
2211 
2212   // Simd-ize the loop with if condition
2213   OMPBuilder.applySimd(CLI, AlignedVars, IfCmp, OrderKind::OMP_ORDER_unknown,
2214                        ConstantInt::get(Type::getInt32Ty(Ctx), 3),
2215                        /* Safelen */ nullptr);
2216 
2217   OMPBuilder.finalize();
2218   EXPECT_FALSE(verifyModule(*M, &errs()));
2219 
2220   PassBuilder PB;
2221   FunctionAnalysisManager FAM;
2222   PB.registerFunctionAnalyses(FAM);
2223   LoopInfo &LI = FAM.getResult<LoopAnalysis>(*F);
2224 
2225   // Check if there are two loops (one with enabled vectorization)
2226   const std::vector<Loop *> &TopLvl = LI.getTopLevelLoops();
2227   EXPECT_EQ(TopLvl.size(), 2u);
2228 
2229   Loop *L = TopLvl[0];
2230   EXPECT_TRUE(findStringMetadataForLoop(L, "llvm.loop.parallel_accesses"));
2231   EXPECT_TRUE(getBooleanLoopAttribute(L, "llvm.loop.vectorize.enable"));
2232   EXPECT_EQ(getIntLoopAttribute(L, "llvm.loop.vectorize.width"), 3);
2233 
2234   // The second loop should have disabled vectorization
2235   L = TopLvl[1];
2236   EXPECT_FALSE(findStringMetadataForLoop(L, "llvm.loop.parallel_accesses"));
2237   EXPECT_FALSE(getBooleanLoopAttribute(L, "llvm.loop.vectorize.enable"));
2238   // Check for llvm.access.group metadata attached to the printf
2239   // function in the loop body.
2240   BasicBlock *LoopBody = CLI->getBody();
2241   EXPECT_TRUE(any_of(*LoopBody, [](Instruction &I) {
2242     return I.getMetadata("llvm.access.group") != nullptr;
2243   }));
2244 }
2245 
2246 TEST_F(OpenMPIRBuilderTest, UnrollLoopFull) {
2247   OpenMPIRBuilder OMPBuilder(*M);
2248 
2249   CanonicalLoopInfo *CLI = buildSingleLoopFunction(DL, OMPBuilder, 32);
2250 
2251   // Unroll the loop.
2252   OMPBuilder.unrollLoopFull(DL, CLI);
2253 
2254   OMPBuilder.finalize();
2255   EXPECT_FALSE(verifyModule(*M, &errs()));
2256 
2257   PassBuilder PB;
2258   FunctionAnalysisManager FAM;
2259   PB.registerFunctionAnalyses(FAM);
2260   LoopInfo &LI = FAM.getResult<LoopAnalysis>(*F);
2261 
2262   const std::vector<Loop *> &TopLvl = LI.getTopLevelLoops();
2263   EXPECT_EQ(TopLvl.size(), 1u);
2264 
2265   Loop *L = TopLvl.front();
2266   EXPECT_TRUE(getBooleanLoopAttribute(L, "llvm.loop.unroll.enable"));
2267   EXPECT_TRUE(getBooleanLoopAttribute(L, "llvm.loop.unroll.full"));
2268 }
2269 
2270 TEST_F(OpenMPIRBuilderTest, UnrollLoopPartial) {
2271   OpenMPIRBuilder OMPBuilder(*M);
2272   CanonicalLoopInfo *CLI = buildSingleLoopFunction(DL, OMPBuilder, 32);
2273 
2274   // Unroll the loop.
2275   CanonicalLoopInfo *UnrolledLoop = nullptr;
2276   OMPBuilder.unrollLoopPartial(DL, CLI, 5, &UnrolledLoop);
2277   ASSERT_NE(UnrolledLoop, nullptr);
2278 
2279   OMPBuilder.finalize();
2280   EXPECT_FALSE(verifyModule(*M, &errs()));
2281   UnrolledLoop->assertOK();
2282 
2283   PassBuilder PB;
2284   FunctionAnalysisManager FAM;
2285   PB.registerFunctionAnalyses(FAM);
2286   LoopInfo &LI = FAM.getResult<LoopAnalysis>(*F);
2287 
2288   const std::vector<Loop *> &TopLvl = LI.getTopLevelLoops();
2289   EXPECT_EQ(TopLvl.size(), 1u);
2290   Loop *Outer = TopLvl.front();
2291   EXPECT_EQ(Outer->getHeader(), UnrolledLoop->getHeader());
2292   EXPECT_EQ(Outer->getLoopLatch(), UnrolledLoop->getLatch());
2293   EXPECT_EQ(Outer->getExitingBlock(), UnrolledLoop->getCond());
2294   EXPECT_EQ(Outer->getExitBlock(), UnrolledLoop->getExit());
2295 
2296   EXPECT_EQ(Outer->getSubLoops().size(), 1u);
2297   Loop *Inner = Outer->getSubLoops().front();
2298 
2299   EXPECT_TRUE(getBooleanLoopAttribute(Inner, "llvm.loop.unroll.enable"));
2300   EXPECT_EQ(getIntLoopAttribute(Inner, "llvm.loop.unroll.count"), 5);
2301 }
2302 
2303 TEST_F(OpenMPIRBuilderTest, UnrollLoopHeuristic) {
2304   OpenMPIRBuilder OMPBuilder(*M);
2305 
2306   CanonicalLoopInfo *CLI = buildSingleLoopFunction(DL, OMPBuilder, 32);
2307 
2308   // Unroll the loop.
2309   OMPBuilder.unrollLoopHeuristic(DL, CLI);
2310 
2311   OMPBuilder.finalize();
2312   EXPECT_FALSE(verifyModule(*M, &errs()));
2313 
2314   PassBuilder PB;
2315   FunctionAnalysisManager FAM;
2316   PB.registerFunctionAnalyses(FAM);
2317   LoopInfo &LI = FAM.getResult<LoopAnalysis>(*F);
2318 
2319   const std::vector<Loop *> &TopLvl = LI.getTopLevelLoops();
2320   EXPECT_EQ(TopLvl.size(), 1u);
2321 
2322   Loop *L = TopLvl.front();
2323   EXPECT_TRUE(getBooleanLoopAttribute(L, "llvm.loop.unroll.enable"));
2324 }
2325 
2326 TEST_F(OpenMPIRBuilderTest, StaticWorkshareLoopTarget) {
2327   using InsertPointTy = OpenMPIRBuilder::InsertPointTy;
2328   std::string oldDLStr = M->getDataLayoutStr();
2329   M->setDataLayout(
2330       "e-p:64:64-p1:64:64-p2:32:32-p3:32:32-p4:64:64-p5:32:32-p6:32:32-p7:160:"
2331       "256:256:32-p8:128:128-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:"
2332       "256-v256:256-v512:512-v1024:1024-v2048:2048-n32:64-S32-A5-G1-ni:7:8");
2333   OpenMPIRBuilder OMPBuilder(*M);
2334   OMPBuilder.Config.IsTargetDevice = true;
2335   OMPBuilder.initialize();
2336   IRBuilder<> Builder(BB);
2337   OpenMPIRBuilder::LocationDescription Loc({Builder.saveIP(), DL});
2338   InsertPointTy AllocaIP = Builder.saveIP();
2339 
2340   Type *LCTy = Type::getInt32Ty(Ctx);
2341   Value *StartVal = ConstantInt::get(LCTy, 10);
2342   Value *StopVal = ConstantInt::get(LCTy, 52);
2343   Value *StepVal = ConstantInt::get(LCTy, 2);
2344   auto LoopBodyGen = [&](InsertPointTy, Value *) { return Error::success(); };
2345 
2346   Expected<CanonicalLoopInfo *> LoopResult = OMPBuilder.createCanonicalLoop(
2347       Loc, LoopBodyGen, StartVal, StopVal, StepVal, false, false);
2348   assert(LoopResult && "unexpected error");
2349   CanonicalLoopInfo *CLI = *LoopResult;
2350   BasicBlock *Preheader = CLI->getPreheader();
2351   Value *TripCount = CLI->getTripCount();
2352 
2353   Builder.SetInsertPoint(BB, BB->getFirstInsertionPt());
2354 
2355   OpenMPIRBuilder::InsertPointOrErrorTy AfterIP = OMPBuilder.applyWorkshareLoop(
2356       DL, CLI, AllocaIP, true, OMP_SCHEDULE_Static, nullptr, false, false,
2357       false, false, WorksharingLoopType::ForStaticLoop);
2358   assert(AfterIP && "unexpected error");
2359   Builder.restoreIP(*AfterIP);
2360   Builder.CreateRetVoid();
2361 
2362   OMPBuilder.finalize();
2363   EXPECT_FALSE(verifyModule(*M, &errs()));
2364 
2365   CallInst *WorkshareLoopRuntimeCall = nullptr;
2366   int WorkshareLoopRuntimeCallCnt = 0;
2367   for (auto Inst = Preheader->begin(); Inst != Preheader->end(); ++Inst) {
2368     CallInst *Call = dyn_cast<CallInst>(Inst);
2369     if (!Call)
2370       continue;
2371     if (!Call->getCalledFunction())
2372       continue;
2373 
2374     if (Call->getCalledFunction()->getName() == "__kmpc_for_static_loop_4u") {
2375       WorkshareLoopRuntimeCall = Call;
2376       WorkshareLoopRuntimeCallCnt++;
2377     }
2378   }
2379   EXPECT_NE(WorkshareLoopRuntimeCall, nullptr);
2380   // Verify that there is only one call to workshare loop function
2381   EXPECT_EQ(WorkshareLoopRuntimeCallCnt, 1);
2382   // Check that pointer to loop body function is passed as second argument
2383   Value *LoopBodyFuncArg = WorkshareLoopRuntimeCall->getArgOperand(1);
2384   EXPECT_EQ(Builder.getPtrTy(), LoopBodyFuncArg->getType());
2385   Function *ArgFunction = dyn_cast<Function>(LoopBodyFuncArg);
2386   EXPECT_NE(ArgFunction, nullptr);
2387   EXPECT_EQ(ArgFunction->arg_size(), 1u);
2388   EXPECT_EQ(ArgFunction->getArg(0)->getType(), TripCount->getType());
2389   // Check that no variables except for loop counter are used in loop body
2390   EXPECT_EQ(Constant::getNullValue(Builder.getPtrTy()),
2391             WorkshareLoopRuntimeCall->getArgOperand(2));
2392   // Check loop trip count argument
2393   EXPECT_EQ(TripCount, WorkshareLoopRuntimeCall->getArgOperand(3));
2394 }
2395 
2396 TEST_F(OpenMPIRBuilderTest, StaticWorkShareLoop) {
2397   using InsertPointTy = OpenMPIRBuilder::InsertPointTy;
2398   OpenMPIRBuilder OMPBuilder(*M);
2399   OMPBuilder.Config.IsTargetDevice = false;
2400   OMPBuilder.initialize();
2401   IRBuilder<> Builder(BB);
2402   OpenMPIRBuilder::LocationDescription Loc({Builder.saveIP(), DL});
2403 
2404   Type *LCTy = Type::getInt32Ty(Ctx);
2405   Value *StartVal = ConstantInt::get(LCTy, 10);
2406   Value *StopVal = ConstantInt::get(LCTy, 52);
2407   Value *StepVal = ConstantInt::get(LCTy, 2);
2408   auto LoopBodyGen = [&](InsertPointTy, llvm::Value *) {
2409     return Error::success();
2410   };
2411 
2412   Expected<CanonicalLoopInfo *> LoopResult = OMPBuilder.createCanonicalLoop(
2413       Loc, LoopBodyGen, StartVal, StopVal, StepVal,
2414       /*IsSigned=*/false, /*InclusiveStop=*/false);
2415   assert(LoopResult && "unexpected error");
2416   CanonicalLoopInfo *CLI = *LoopResult;
2417   BasicBlock *Preheader = CLI->getPreheader();
2418   BasicBlock *Body = CLI->getBody();
2419   Value *IV = CLI->getIndVar();
2420   BasicBlock *ExitBlock = CLI->getExit();
2421 
2422   Builder.SetInsertPoint(BB, BB->getFirstInsertionPt());
2423   InsertPointTy AllocaIP = Builder.saveIP();
2424 
2425   OpenMPIRBuilder::InsertPointOrErrorTy AfterIP = OMPBuilder.applyWorkshareLoop(
2426       DL, CLI, AllocaIP, /*NeedsBarrier=*/true, OMP_SCHEDULE_Static);
2427   assert(AfterIP && "unexpected error");
2428 
2429   BasicBlock *Cond = Body->getSinglePredecessor();
2430   Instruction *Cmp = &*Cond->begin();
2431   Value *TripCount = Cmp->getOperand(1);
2432 
2433   auto AllocaIter = BB->begin();
2434   ASSERT_GE(std::distance(BB->begin(), BB->end()), 4);
2435   AllocaInst *PLastIter = dyn_cast<AllocaInst>(&*(AllocaIter++));
2436   AllocaInst *PLowerBound = dyn_cast<AllocaInst>(&*(AllocaIter++));
2437   AllocaInst *PUpperBound = dyn_cast<AllocaInst>(&*(AllocaIter++));
2438   AllocaInst *PStride = dyn_cast<AllocaInst>(&*(AllocaIter++));
2439   EXPECT_NE(PLastIter, nullptr);
2440   EXPECT_NE(PLowerBound, nullptr);
2441   EXPECT_NE(PUpperBound, nullptr);
2442   EXPECT_NE(PStride, nullptr);
2443 
2444   auto PreheaderIter = Preheader->begin();
2445   ASSERT_GE(std::distance(Preheader->begin(), Preheader->end()), 7);
2446   StoreInst *LowerBoundStore = dyn_cast<StoreInst>(&*(PreheaderIter++));
2447   StoreInst *UpperBoundStore = dyn_cast<StoreInst>(&*(PreheaderIter++));
2448   StoreInst *StrideStore = dyn_cast<StoreInst>(&*(PreheaderIter++));
2449   ASSERT_NE(LowerBoundStore, nullptr);
2450   ASSERT_NE(UpperBoundStore, nullptr);
2451   ASSERT_NE(StrideStore, nullptr);
2452 
2453   auto *OrigLowerBound =
2454       dyn_cast<ConstantInt>(LowerBoundStore->getValueOperand());
2455   auto *OrigUpperBound =
2456       dyn_cast<ConstantInt>(UpperBoundStore->getValueOperand());
2457   auto *OrigStride = dyn_cast<ConstantInt>(StrideStore->getValueOperand());
2458   ASSERT_NE(OrigLowerBound, nullptr);
2459   ASSERT_NE(OrigUpperBound, nullptr);
2460   ASSERT_NE(OrigStride, nullptr);
2461   EXPECT_EQ(OrigLowerBound->getValue(), 0);
2462   EXPECT_EQ(OrigUpperBound->getValue(), 20);
2463   EXPECT_EQ(OrigStride->getValue(), 1);
2464 
2465   // Check that the loop IV is updated to account for the lower bound returned
2466   // by the OpenMP runtime call.
2467   BinaryOperator *Add = dyn_cast<BinaryOperator>(&Body->front());
2468   EXPECT_EQ(Add->getOperand(0), IV);
2469   auto *LoadedLowerBound = dyn_cast<LoadInst>(Add->getOperand(1));
2470   ASSERT_NE(LoadedLowerBound, nullptr);
2471   EXPECT_EQ(LoadedLowerBound->getPointerOperand(), PLowerBound);
2472 
2473   // Check that the trip count is updated to account for the lower and upper
2474   // bounds return by the OpenMP runtime call.
2475   auto *AddOne = dyn_cast<Instruction>(TripCount);
2476   ASSERT_NE(AddOne, nullptr);
2477   ASSERT_TRUE(AddOne->isBinaryOp());
2478   auto *One = dyn_cast<ConstantInt>(AddOne->getOperand(1));
2479   ASSERT_NE(One, nullptr);
2480   EXPECT_EQ(One->getValue(), 1);
2481   auto *Difference = dyn_cast<Instruction>(AddOne->getOperand(0));
2482   ASSERT_NE(Difference, nullptr);
2483   ASSERT_TRUE(Difference->isBinaryOp());
2484   EXPECT_EQ(Difference->getOperand(1), LoadedLowerBound);
2485   auto *LoadedUpperBound = dyn_cast<LoadInst>(Difference->getOperand(0));
2486   ASSERT_NE(LoadedUpperBound, nullptr);
2487   EXPECT_EQ(LoadedUpperBound->getPointerOperand(), PUpperBound);
2488 
2489   // The original loop iterator should only be used in the condition, in the
2490   // increment and in the statement that adds the lower bound to it.
2491   EXPECT_EQ(std::distance(IV->use_begin(), IV->use_end()), 3);
2492 
2493   // The exit block should contain the "fini" call and the barrier call,
2494   // plus the call to obtain the thread ID.
2495   size_t NumCallsInExitBlock =
2496       count_if(*ExitBlock, [](Instruction &I) { return isa<CallInst>(I); });
2497   EXPECT_EQ(NumCallsInExitBlock, 3u);
2498 }
2499 
2500 TEST_P(OpenMPIRBuilderTestWithIVBits, StaticChunkedWorkshareLoop) {
2501   unsigned IVBits = GetParam();
2502 
2503   using InsertPointTy = OpenMPIRBuilder::InsertPointTy;
2504   OpenMPIRBuilder OMPBuilder(*M);
2505   OMPBuilder.Config.IsTargetDevice = false;
2506 
2507   BasicBlock *Body;
2508   CallInst *Call;
2509   CanonicalLoopInfo *CLI =
2510       buildSingleLoopFunction(DL, OMPBuilder, IVBits, &Call, &Body);
2511 
2512   Instruction *OrigIndVar = CLI->getIndVar();
2513   EXPECT_EQ(Call->getOperand(1), OrigIndVar);
2514 
2515   Type *LCTy = Type::getInt32Ty(Ctx);
2516   Value *ChunkSize = ConstantInt::get(LCTy, 5);
2517   InsertPointTy AllocaIP{&F->getEntryBlock(),
2518                          F->getEntryBlock().getFirstInsertionPt()};
2519   OpenMPIRBuilder::InsertPointOrErrorTy AfterIP = OMPBuilder.applyWorkshareLoop(
2520       DL, CLI, AllocaIP, /*NeedsBarrier=*/true, OMP_SCHEDULE_Static, ChunkSize);
2521   assert(AfterIP && "unexpected error");
2522 
2523   OMPBuilder.finalize();
2524   EXPECT_FALSE(verifyModule(*M, &errs()));
2525 
2526   BasicBlock *Entry = &F->getEntryBlock();
2527   BasicBlock *Preheader = Entry->getSingleSuccessor();
2528 
2529   BasicBlock *DispatchPreheader = Preheader->getSingleSuccessor();
2530   BasicBlock *DispatchHeader = DispatchPreheader->getSingleSuccessor();
2531   BasicBlock *DispatchCond = DispatchHeader->getSingleSuccessor();
2532   BasicBlock *DispatchBody = succ_begin(DispatchCond)[0];
2533   BasicBlock *DispatchExit = succ_begin(DispatchCond)[1];
2534   BasicBlock *DispatchAfter = DispatchExit->getSingleSuccessor();
2535   BasicBlock *Return = DispatchAfter->getSingleSuccessor();
2536 
2537   BasicBlock *ChunkPreheader = DispatchBody->getSingleSuccessor();
2538   BasicBlock *ChunkHeader = ChunkPreheader->getSingleSuccessor();
2539   BasicBlock *ChunkCond = ChunkHeader->getSingleSuccessor();
2540   BasicBlock *ChunkBody = succ_begin(ChunkCond)[0];
2541   BasicBlock *ChunkExit = succ_begin(ChunkCond)[1];
2542   BasicBlock *ChunkInc = ChunkBody->getSingleSuccessor();
2543   BasicBlock *ChunkAfter = ChunkExit->getSingleSuccessor();
2544 
2545   BasicBlock *DispatchInc = ChunkAfter;
2546 
2547   EXPECT_EQ(ChunkBody, Body);
2548   EXPECT_EQ(ChunkInc->getSingleSuccessor(), ChunkHeader);
2549   EXPECT_EQ(DispatchInc->getSingleSuccessor(), DispatchHeader);
2550 
2551   EXPECT_TRUE(isa<ReturnInst>(Return->front()));
2552 
2553   Value *NewIV = Call->getOperand(1);
2554   EXPECT_EQ(NewIV->getType()->getScalarSizeInBits(), IVBits);
2555 
2556   CallInst *InitCall = findSingleCall(
2557       F,
2558       (IVBits > 32) ? omp::RuntimeFunction::OMPRTL___kmpc_for_static_init_8u
2559                     : omp::RuntimeFunction::OMPRTL___kmpc_for_static_init_4u,
2560       OMPBuilder);
2561   EXPECT_EQ(InitCall->getParent(), Preheader);
2562   EXPECT_EQ(cast<ConstantInt>(InitCall->getArgOperand(2))->getSExtValue(), 33);
2563   EXPECT_EQ(cast<ConstantInt>(InitCall->getArgOperand(7))->getSExtValue(), 1);
2564   EXPECT_EQ(cast<ConstantInt>(InitCall->getArgOperand(8))->getSExtValue(), 5);
2565 
2566   CallInst *FiniCall = findSingleCall(
2567       F, omp::RuntimeFunction::OMPRTL___kmpc_for_static_fini, OMPBuilder);
2568   EXPECT_EQ(FiniCall->getParent(), DispatchExit);
2569 
2570   CallInst *BarrierCall = findSingleCall(
2571       F, omp::RuntimeFunction::OMPRTL___kmpc_barrier, OMPBuilder);
2572   EXPECT_EQ(BarrierCall->getParent(), DispatchExit);
2573 }
2574 
2575 INSTANTIATE_TEST_SUITE_P(IVBits, OpenMPIRBuilderTestWithIVBits,
2576                          ::testing::Values(8, 16, 32, 64));
2577 
2578 TEST_P(OpenMPIRBuilderTestWithParams, DynamicWorkShareLoop) {
2579   using InsertPointTy = OpenMPIRBuilder::InsertPointTy;
2580   OpenMPIRBuilder OMPBuilder(*M);
2581   OMPBuilder.Config.IsTargetDevice = false;
2582   OMPBuilder.initialize();
2583   IRBuilder<> Builder(BB);
2584   OpenMPIRBuilder::LocationDescription Loc({Builder.saveIP(), DL});
2585 
2586   omp::OMPScheduleType SchedType = GetParam();
2587   uint32_t ChunkSize = 1;
2588   switch (SchedType & ~OMPScheduleType::ModifierMask) {
2589   case omp::OMPScheduleType::BaseDynamicChunked:
2590   case omp::OMPScheduleType::BaseGuidedChunked:
2591     ChunkSize = 7;
2592     break;
2593   case omp::OMPScheduleType::BaseAuto:
2594   case omp::OMPScheduleType::BaseRuntime:
2595     ChunkSize = 1;
2596     break;
2597   default:
2598     assert(0 && "unknown type for this test");
2599     break;
2600   }
2601 
2602   Type *LCTy = Type::getInt32Ty(Ctx);
2603   Value *StartVal = ConstantInt::get(LCTy, 10);
2604   Value *StopVal = ConstantInt::get(LCTy, 52);
2605   Value *StepVal = ConstantInt::get(LCTy, 2);
2606   Value *ChunkVal =
2607       (ChunkSize == 1) ? nullptr : ConstantInt::get(LCTy, ChunkSize);
2608   auto LoopBodyGen = [&](InsertPointTy, llvm::Value *) {
2609     return Error::success();
2610   };
2611 
2612   Expected<CanonicalLoopInfo *> LoopResult = OMPBuilder.createCanonicalLoop(
2613       Loc, LoopBodyGen, StartVal, StopVal, StepVal,
2614       /*IsSigned=*/false, /*InclusiveStop=*/false);
2615   assert(LoopResult && "unexpected error");
2616   CanonicalLoopInfo *CLI = *LoopResult;
2617 
2618   Builder.SetInsertPoint(BB, BB->getFirstInsertionPt());
2619   InsertPointTy AllocaIP = Builder.saveIP();
2620 
2621   // Collect all the info from CLI, as it isn't usable after the call to
2622   // createDynamicWorkshareLoop.
2623   InsertPointTy AfterIP = CLI->getAfterIP();
2624   BasicBlock *Preheader = CLI->getPreheader();
2625   BasicBlock *ExitBlock = CLI->getExit();
2626   BasicBlock *LatchBlock = CLI->getLatch();
2627   Value *IV = CLI->getIndVar();
2628 
2629   OpenMPIRBuilder::InsertPointOrErrorTy EndIP = OMPBuilder.applyWorkshareLoop(
2630       DL, CLI, AllocaIP, /*NeedsBarrier=*/true, getSchedKind(SchedType),
2631       ChunkVal, /*Simd=*/false,
2632       (SchedType & omp::OMPScheduleType::ModifierMonotonic) ==
2633           omp::OMPScheduleType::ModifierMonotonic,
2634       (SchedType & omp::OMPScheduleType::ModifierNonmonotonic) ==
2635           omp::OMPScheduleType::ModifierNonmonotonic,
2636       /*Ordered=*/false);
2637   assert(EndIP && "unexpected error");
2638 
2639   // The returned value should be the "after" point.
2640   ASSERT_EQ(EndIP->getBlock(), AfterIP.getBlock());
2641   ASSERT_EQ(EndIP->getPoint(), AfterIP.getPoint());
2642 
2643   auto AllocaIter = BB->begin();
2644   ASSERT_GE(std::distance(BB->begin(), BB->end()), 4);
2645   AllocaInst *PLastIter = dyn_cast<AllocaInst>(&*(AllocaIter++));
2646   AllocaInst *PLowerBound = dyn_cast<AllocaInst>(&*(AllocaIter++));
2647   AllocaInst *PUpperBound = dyn_cast<AllocaInst>(&*(AllocaIter++));
2648   AllocaInst *PStride = dyn_cast<AllocaInst>(&*(AllocaIter++));
2649   EXPECT_NE(PLastIter, nullptr);
2650   EXPECT_NE(PLowerBound, nullptr);
2651   EXPECT_NE(PUpperBound, nullptr);
2652   EXPECT_NE(PStride, nullptr);
2653 
2654   auto PreheaderIter = Preheader->begin();
2655   ASSERT_GE(std::distance(Preheader->begin(), Preheader->end()), 6);
2656   StoreInst *LowerBoundStore = dyn_cast<StoreInst>(&*(PreheaderIter++));
2657   StoreInst *UpperBoundStore = dyn_cast<StoreInst>(&*(PreheaderIter++));
2658   StoreInst *StrideStore = dyn_cast<StoreInst>(&*(PreheaderIter++));
2659   ASSERT_NE(LowerBoundStore, nullptr);
2660   ASSERT_NE(UpperBoundStore, nullptr);
2661   ASSERT_NE(StrideStore, nullptr);
2662 
2663   CallInst *ThreadIdCall = dyn_cast<CallInst>(&*(PreheaderIter++));
2664   ASSERT_NE(ThreadIdCall, nullptr);
2665   EXPECT_EQ(ThreadIdCall->getCalledFunction()->getName(),
2666             "__kmpc_global_thread_num");
2667 
2668   CallInst *InitCall = dyn_cast<CallInst>(&*PreheaderIter);
2669 
2670   ASSERT_NE(InitCall, nullptr);
2671   EXPECT_EQ(InitCall->getCalledFunction()->getName(),
2672             "__kmpc_dispatch_init_4u");
2673   EXPECT_EQ(InitCall->arg_size(), 7U);
2674   EXPECT_EQ(InitCall->getArgOperand(6), ConstantInt::get(LCTy, ChunkSize));
2675   ConstantInt *SchedVal = cast<ConstantInt>(InitCall->getArgOperand(2));
2676   if ((SchedType & OMPScheduleType::MonotonicityMask) ==
2677       OMPScheduleType::None) {
2678     // Implementation is allowed to add default nonmonotonicity flag
2679     EXPECT_EQ(
2680         static_cast<OMPScheduleType>(SchedVal->getValue().getZExtValue()) |
2681             OMPScheduleType::ModifierNonmonotonic,
2682         SchedType | OMPScheduleType::ModifierNonmonotonic);
2683   } else {
2684     EXPECT_EQ(static_cast<OMPScheduleType>(SchedVal->getValue().getZExtValue()),
2685               SchedType);
2686   }
2687 
2688   ConstantInt *OrigLowerBound =
2689       dyn_cast<ConstantInt>(LowerBoundStore->getValueOperand());
2690   ConstantInt *OrigUpperBound =
2691       dyn_cast<ConstantInt>(UpperBoundStore->getValueOperand());
2692   ConstantInt *OrigStride =
2693       dyn_cast<ConstantInt>(StrideStore->getValueOperand());
2694   ASSERT_NE(OrigLowerBound, nullptr);
2695   ASSERT_NE(OrigUpperBound, nullptr);
2696   ASSERT_NE(OrigStride, nullptr);
2697   EXPECT_EQ(OrigLowerBound->getValue(), 1);
2698   EXPECT_EQ(OrigUpperBound->getValue(), 21);
2699   EXPECT_EQ(OrigStride->getValue(), 1);
2700 
2701   CallInst *FiniCall = dyn_cast<CallInst>(
2702       &*(LatchBlock->getTerminator()->getPrevNonDebugInstruction(true)));
2703   EXPECT_EQ(FiniCall, nullptr);
2704 
2705   // The original loop iterator should only be used in the condition, in the
2706   // increment and in the statement that adds the lower bound to it.
2707   EXPECT_EQ(std::distance(IV->use_begin(), IV->use_end()), 3);
2708 
2709   // The exit block should contain the barrier call, plus the call to obtain
2710   // the thread ID.
2711   size_t NumCallsInExitBlock =
2712       count_if(*ExitBlock, [](Instruction &I) { return isa<CallInst>(I); });
2713   EXPECT_EQ(NumCallsInExitBlock, 2u);
2714 
2715   // Add a termination to our block and check that it is internally consistent.
2716   Builder.restoreIP(*EndIP);
2717   Builder.CreateRetVoid();
2718   OMPBuilder.finalize();
2719   EXPECT_FALSE(verifyModule(*M, &errs()));
2720 }
2721 
2722 INSTANTIATE_TEST_SUITE_P(
2723     OpenMPWSLoopSchedulingTypes, OpenMPIRBuilderTestWithParams,
2724     ::testing::Values(omp::OMPScheduleType::UnorderedDynamicChunked,
2725                       omp::OMPScheduleType::UnorderedGuidedChunked,
2726                       omp::OMPScheduleType::UnorderedAuto,
2727                       omp::OMPScheduleType::UnorderedRuntime,
2728                       omp::OMPScheduleType::UnorderedDynamicChunked |
2729                           omp::OMPScheduleType::ModifierMonotonic,
2730                       omp::OMPScheduleType::UnorderedDynamicChunked |
2731                           omp::OMPScheduleType::ModifierNonmonotonic,
2732                       omp::OMPScheduleType::UnorderedGuidedChunked |
2733                           omp::OMPScheduleType::ModifierMonotonic,
2734                       omp::OMPScheduleType::UnorderedGuidedChunked |
2735                           omp::OMPScheduleType::ModifierNonmonotonic,
2736                       omp::OMPScheduleType::UnorderedAuto |
2737                           omp::OMPScheduleType::ModifierMonotonic,
2738                       omp::OMPScheduleType::UnorderedRuntime |
2739                           omp::OMPScheduleType::ModifierMonotonic));
2740 
2741 TEST_F(OpenMPIRBuilderTest, DynamicWorkShareLoopOrdered) {
2742   using InsertPointTy = OpenMPIRBuilder::InsertPointTy;
2743   OpenMPIRBuilder OMPBuilder(*M);
2744   OMPBuilder.Config.IsTargetDevice = false;
2745   OMPBuilder.initialize();
2746   IRBuilder<> Builder(BB);
2747   OpenMPIRBuilder::LocationDescription Loc({Builder.saveIP(), DL});
2748 
2749   uint32_t ChunkSize = 1;
2750   Type *LCTy = Type::getInt32Ty(Ctx);
2751   Value *StartVal = ConstantInt::get(LCTy, 10);
2752   Value *StopVal = ConstantInt::get(LCTy, 52);
2753   Value *StepVal = ConstantInt::get(LCTy, 2);
2754   Value *ChunkVal = ConstantInt::get(LCTy, ChunkSize);
2755   auto LoopBodyGen = [&](InsertPointTy, llvm::Value *) {
2756     return llvm::Error::success();
2757   };
2758 
2759   Expected<CanonicalLoopInfo *> LoopResult = OMPBuilder.createCanonicalLoop(
2760       Loc, LoopBodyGen, StartVal, StopVal, StepVal,
2761       /*IsSigned=*/false, /*InclusiveStop=*/false);
2762   assert(LoopResult && "unexpected error");
2763   CanonicalLoopInfo *CLI = *LoopResult;
2764 
2765   Builder.SetInsertPoint(BB, BB->getFirstInsertionPt());
2766   InsertPointTy AllocaIP = Builder.saveIP();
2767 
2768   // Collect all the info from CLI, as it isn't usable after the call to
2769   // createDynamicWorkshareLoop.
2770   BasicBlock *Preheader = CLI->getPreheader();
2771   BasicBlock *ExitBlock = CLI->getExit();
2772   BasicBlock *LatchBlock = CLI->getLatch();
2773   Value *IV = CLI->getIndVar();
2774 
2775   OpenMPIRBuilder::InsertPointOrErrorTy EndIP = OMPBuilder.applyWorkshareLoop(
2776       DL, CLI, AllocaIP, /*NeedsBarrier=*/true, OMP_SCHEDULE_Static, ChunkVal,
2777       /*HasSimdModifier=*/false, /*HasMonotonicModifier=*/false,
2778       /*HasNonmonotonicModifier=*/false,
2779       /*HasOrderedClause=*/true);
2780   assert(EndIP && "unexpected error");
2781 
2782   // Add a termination to our block and check that it is internally consistent.
2783   Builder.restoreIP(*EndIP);
2784   Builder.CreateRetVoid();
2785   OMPBuilder.finalize();
2786   EXPECT_FALSE(verifyModule(*M, &errs()));
2787 
2788   CallInst *InitCall = nullptr;
2789   for (Instruction &EI : *Preheader) {
2790     Instruction *Cur = &EI;
2791     if (isa<CallInst>(Cur)) {
2792       InitCall = cast<CallInst>(Cur);
2793       if (InitCall->getCalledFunction()->getName() == "__kmpc_dispatch_init_4u")
2794         break;
2795       InitCall = nullptr;
2796     }
2797   }
2798   EXPECT_NE(InitCall, nullptr);
2799   EXPECT_EQ(InitCall->arg_size(), 7U);
2800   ConstantInt *SchedVal = cast<ConstantInt>(InitCall->getArgOperand(2));
2801   EXPECT_EQ(SchedVal->getValue(),
2802             static_cast<uint64_t>(OMPScheduleType::OrderedStaticChunked));
2803 
2804   CallInst *FiniCall = dyn_cast<CallInst>(
2805       &*(LatchBlock->getTerminator()->getPrevNonDebugInstruction(true)));
2806   ASSERT_NE(FiniCall, nullptr);
2807   EXPECT_EQ(FiniCall->getCalledFunction()->getName(),
2808             "__kmpc_dispatch_fini_4u");
2809   EXPECT_EQ(FiniCall->arg_size(), 2U);
2810   EXPECT_EQ(InitCall->getArgOperand(0), FiniCall->getArgOperand(0));
2811   EXPECT_EQ(InitCall->getArgOperand(1), FiniCall->getArgOperand(1));
2812 
2813   // The original loop iterator should only be used in the condition, in the
2814   // increment and in the statement that adds the lower bound to it.
2815   EXPECT_EQ(std::distance(IV->use_begin(), IV->use_end()), 3);
2816 
2817   // The exit block should contain the barrier call, plus the call to obtain
2818   // the thread ID.
2819   size_t NumCallsInExitBlock =
2820       count_if(*ExitBlock, [](Instruction &I) { return isa<CallInst>(I); });
2821   EXPECT_EQ(NumCallsInExitBlock, 2u);
2822 }
2823 
2824 TEST_F(OpenMPIRBuilderTest, MasterDirective) {
2825   using InsertPointTy = OpenMPIRBuilder::InsertPointTy;
2826   OpenMPIRBuilder OMPBuilder(*M);
2827   OMPBuilder.initialize();
2828   F->setName("func");
2829   IRBuilder<> Builder(BB);
2830 
2831   OpenMPIRBuilder::LocationDescription Loc({Builder.saveIP(), DL});
2832 
2833   AllocaInst *PrivAI = nullptr;
2834 
2835   BasicBlock *EntryBB = nullptr;
2836   BasicBlock *ThenBB = nullptr;
2837 
2838   auto BodyGenCB = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP) {
2839     if (AllocaIP.isSet())
2840       Builder.restoreIP(AllocaIP);
2841     else
2842       Builder.SetInsertPoint(&*(F->getEntryBlock().getFirstInsertionPt()));
2843     PrivAI = Builder.CreateAlloca(F->arg_begin()->getType());
2844     Builder.CreateStore(F->arg_begin(), PrivAI);
2845 
2846     llvm::BasicBlock *CodeGenIPBB = CodeGenIP.getBlock();
2847     llvm::Instruction *CodeGenIPInst = &*CodeGenIP.getPoint();
2848     EXPECT_EQ(CodeGenIPBB->getTerminator(), CodeGenIPInst);
2849 
2850     Builder.restoreIP(CodeGenIP);
2851 
2852     // collect some info for checks later
2853     ThenBB = Builder.GetInsertBlock();
2854     EntryBB = ThenBB->getUniquePredecessor();
2855 
2856     // simple instructions for body
2857     Value *PrivLoad =
2858         Builder.CreateLoad(PrivAI->getAllocatedType(), PrivAI, "local.use");
2859     Builder.CreateICmpNE(F->arg_begin(), PrivLoad);
2860   };
2861 
2862   auto FiniCB = [&](InsertPointTy IP) {
2863     BasicBlock *IPBB = IP.getBlock();
2864     EXPECT_NE(IPBB->end(), IP.getPoint());
2865   };
2866 
2867   OpenMPIRBuilder::InsertPointOrErrorTy AfterIP = OMPBuilder.createMaster(
2868       Builder, BODYGENCB_WRAPPER(BodyGenCB), FINICB_WRAPPER(FiniCB));
2869   assert(AfterIP && "unexpected error");
2870   Builder.restoreIP(*AfterIP);
2871   Value *EntryBBTI = EntryBB->getTerminator();
2872   EXPECT_NE(EntryBBTI, nullptr);
2873   EXPECT_TRUE(isa<BranchInst>(EntryBBTI));
2874   BranchInst *EntryBr = cast<BranchInst>(EntryBB->getTerminator());
2875   EXPECT_TRUE(EntryBr->isConditional());
2876   EXPECT_EQ(EntryBr->getSuccessor(0), ThenBB);
2877   BasicBlock *ExitBB = ThenBB->getUniqueSuccessor();
2878   EXPECT_EQ(EntryBr->getSuccessor(1), ExitBB);
2879 
2880   CmpInst *CondInst = cast<CmpInst>(EntryBr->getCondition());
2881   EXPECT_TRUE(isa<CallInst>(CondInst->getOperand(0)));
2882 
2883   CallInst *MasterEntryCI = cast<CallInst>(CondInst->getOperand(0));
2884   EXPECT_EQ(MasterEntryCI->arg_size(), 2U);
2885   EXPECT_EQ(MasterEntryCI->getCalledFunction()->getName(), "__kmpc_master");
2886   EXPECT_TRUE(isa<GlobalVariable>(MasterEntryCI->getArgOperand(0)));
2887 
2888   CallInst *MasterEndCI = nullptr;
2889   for (auto &FI : *ThenBB) {
2890     Instruction *cur = &FI;
2891     if (isa<CallInst>(cur)) {
2892       MasterEndCI = cast<CallInst>(cur);
2893       if (MasterEndCI->getCalledFunction()->getName() == "__kmpc_end_master")
2894         break;
2895       MasterEndCI = nullptr;
2896     }
2897   }
2898   EXPECT_NE(MasterEndCI, nullptr);
2899   EXPECT_EQ(MasterEndCI->arg_size(), 2U);
2900   EXPECT_TRUE(isa<GlobalVariable>(MasterEndCI->getArgOperand(0)));
2901   EXPECT_EQ(MasterEndCI->getArgOperand(1), MasterEntryCI->getArgOperand(1));
2902 }
2903 
2904 TEST_F(OpenMPIRBuilderTest, MaskedDirective) {
2905   using InsertPointTy = OpenMPIRBuilder::InsertPointTy;
2906   OpenMPIRBuilder OMPBuilder(*M);
2907   OMPBuilder.initialize();
2908   F->setName("func");
2909   IRBuilder<> Builder(BB);
2910 
2911   OpenMPIRBuilder::LocationDescription Loc({Builder.saveIP(), DL});
2912 
2913   AllocaInst *PrivAI = nullptr;
2914 
2915   BasicBlock *EntryBB = nullptr;
2916   BasicBlock *ThenBB = nullptr;
2917 
2918   auto BodyGenCB = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP) {
2919     if (AllocaIP.isSet())
2920       Builder.restoreIP(AllocaIP);
2921     else
2922       Builder.SetInsertPoint(&*(F->getEntryBlock().getFirstInsertionPt()));
2923     PrivAI = Builder.CreateAlloca(F->arg_begin()->getType());
2924     Builder.CreateStore(F->arg_begin(), PrivAI);
2925 
2926     llvm::BasicBlock *CodeGenIPBB = CodeGenIP.getBlock();
2927     llvm::Instruction *CodeGenIPInst = &*CodeGenIP.getPoint();
2928     EXPECT_EQ(CodeGenIPBB->getTerminator(), CodeGenIPInst);
2929 
2930     Builder.restoreIP(CodeGenIP);
2931 
2932     // collect some info for checks later
2933     ThenBB = Builder.GetInsertBlock();
2934     EntryBB = ThenBB->getUniquePredecessor();
2935 
2936     // simple instructions for body
2937     Value *PrivLoad =
2938         Builder.CreateLoad(PrivAI->getAllocatedType(), PrivAI, "local.use");
2939     Builder.CreateICmpNE(F->arg_begin(), PrivLoad);
2940   };
2941 
2942   auto FiniCB = [&](InsertPointTy IP) {
2943     BasicBlock *IPBB = IP.getBlock();
2944     EXPECT_NE(IPBB->end(), IP.getPoint());
2945   };
2946 
2947   Constant *Filter = ConstantInt::get(Type::getInt32Ty(M->getContext()), 0);
2948   OpenMPIRBuilder::InsertPointOrErrorTy AfterIP = OMPBuilder.createMasked(
2949       Builder, BODYGENCB_WRAPPER(BodyGenCB), FINICB_WRAPPER(FiniCB), Filter);
2950   assert(AfterIP && "unexpected error");
2951   Builder.restoreIP(*AfterIP);
2952   Value *EntryBBTI = EntryBB->getTerminator();
2953   EXPECT_NE(EntryBBTI, nullptr);
2954   EXPECT_TRUE(isa<BranchInst>(EntryBBTI));
2955   BranchInst *EntryBr = cast<BranchInst>(EntryBB->getTerminator());
2956   EXPECT_TRUE(EntryBr->isConditional());
2957   EXPECT_EQ(EntryBr->getSuccessor(0), ThenBB);
2958   BasicBlock *ExitBB = ThenBB->getUniqueSuccessor();
2959   EXPECT_EQ(EntryBr->getSuccessor(1), ExitBB);
2960 
2961   CmpInst *CondInst = cast<CmpInst>(EntryBr->getCondition());
2962   EXPECT_TRUE(isa<CallInst>(CondInst->getOperand(0)));
2963 
2964   CallInst *MaskedEntryCI = cast<CallInst>(CondInst->getOperand(0));
2965   EXPECT_EQ(MaskedEntryCI->arg_size(), 3U);
2966   EXPECT_EQ(MaskedEntryCI->getCalledFunction()->getName(), "__kmpc_masked");
2967   EXPECT_TRUE(isa<GlobalVariable>(MaskedEntryCI->getArgOperand(0)));
2968 
2969   CallInst *MaskedEndCI = nullptr;
2970   for (auto &FI : *ThenBB) {
2971     Instruction *cur = &FI;
2972     if (isa<CallInst>(cur)) {
2973       MaskedEndCI = cast<CallInst>(cur);
2974       if (MaskedEndCI->getCalledFunction()->getName() == "__kmpc_end_masked")
2975         break;
2976       MaskedEndCI = nullptr;
2977     }
2978   }
2979   EXPECT_NE(MaskedEndCI, nullptr);
2980   EXPECT_EQ(MaskedEndCI->arg_size(), 2U);
2981   EXPECT_TRUE(isa<GlobalVariable>(MaskedEndCI->getArgOperand(0)));
2982   EXPECT_EQ(MaskedEndCI->getArgOperand(1), MaskedEntryCI->getArgOperand(1));
2983 }
2984 
2985 TEST_F(OpenMPIRBuilderTest, CriticalDirective) {
2986   using InsertPointTy = OpenMPIRBuilder::InsertPointTy;
2987   OpenMPIRBuilder OMPBuilder(*M);
2988   OMPBuilder.initialize();
2989   F->setName("func");
2990   IRBuilder<> Builder(BB);
2991 
2992   OpenMPIRBuilder::LocationDescription Loc({Builder.saveIP(), DL});
2993 
2994   AllocaInst *PrivAI = Builder.CreateAlloca(F->arg_begin()->getType());
2995 
2996   auto BodyGenCB = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP) {
2997     // actual start for bodyCB
2998     llvm::BasicBlock *CodeGenIPBB = CodeGenIP.getBlock();
2999     llvm::Instruction *CodeGenIPInst = &*CodeGenIP.getPoint();
3000     EXPECT_EQ(CodeGenIPBB->getTerminator(), CodeGenIPInst);
3001 
3002     // body begin
3003     Builder.restoreIP(CodeGenIP);
3004     Builder.CreateStore(F->arg_begin(), PrivAI);
3005     Value *PrivLoad =
3006         Builder.CreateLoad(PrivAI->getAllocatedType(), PrivAI, "local.use");
3007     Builder.CreateICmpNE(F->arg_begin(), PrivLoad);
3008   };
3009 
3010   auto FiniCB = [&](InsertPointTy IP) {
3011     BasicBlock *IPBB = IP.getBlock();
3012     EXPECT_NE(IPBB->end(), IP.getPoint());
3013   };
3014   BasicBlock *EntryBB = Builder.GetInsertBlock();
3015 
3016   OpenMPIRBuilder::InsertPointOrErrorTy AfterIP =
3017       OMPBuilder.createCritical(Builder, BODYGENCB_WRAPPER(BodyGenCB),
3018                                 FINICB_WRAPPER(FiniCB), "testCRT", nullptr);
3019   assert(AfterIP && "unexpected error");
3020   Builder.restoreIP(*AfterIP);
3021 
3022   CallInst *CriticalEntryCI = nullptr;
3023   for (auto &EI : *EntryBB) {
3024     Instruction *cur = &EI;
3025     if (isa<CallInst>(cur)) {
3026       CriticalEntryCI = cast<CallInst>(cur);
3027       if (CriticalEntryCI->getCalledFunction()->getName() == "__kmpc_critical")
3028         break;
3029       CriticalEntryCI = nullptr;
3030     }
3031   }
3032   EXPECT_NE(CriticalEntryCI, nullptr);
3033   EXPECT_EQ(CriticalEntryCI->arg_size(), 3U);
3034   EXPECT_EQ(CriticalEntryCI->getCalledFunction()->getName(), "__kmpc_critical");
3035   EXPECT_TRUE(isa<GlobalVariable>(CriticalEntryCI->getArgOperand(0)));
3036 
3037   CallInst *CriticalEndCI = nullptr;
3038   for (auto &FI : *EntryBB) {
3039     Instruction *cur = &FI;
3040     if (isa<CallInst>(cur)) {
3041       CriticalEndCI = cast<CallInst>(cur);
3042       if (CriticalEndCI->getCalledFunction()->getName() ==
3043           "__kmpc_end_critical")
3044         break;
3045       CriticalEndCI = nullptr;
3046     }
3047   }
3048   EXPECT_NE(CriticalEndCI, nullptr);
3049   EXPECT_EQ(CriticalEndCI->arg_size(), 3U);
3050   EXPECT_TRUE(isa<GlobalVariable>(CriticalEndCI->getArgOperand(0)));
3051   EXPECT_EQ(CriticalEndCI->getArgOperand(1), CriticalEntryCI->getArgOperand(1));
3052   PointerType *CriticalNamePtrTy =
3053       PointerType::getUnqual(ArrayType::get(Type::getInt32Ty(Ctx), 8));
3054   EXPECT_EQ(CriticalEndCI->getArgOperand(2), CriticalEntryCI->getArgOperand(2));
3055   GlobalVariable *GV =
3056       dyn_cast<GlobalVariable>(CriticalEndCI->getArgOperand(2));
3057   ASSERT_NE(GV, nullptr);
3058   EXPECT_EQ(GV->getType(), CriticalNamePtrTy);
3059   const DataLayout &DL = M->getDataLayout();
3060   const llvm::Align TypeAlign = DL.getABITypeAlign(CriticalNamePtrTy);
3061   const llvm::Align PtrAlign = DL.getPointerABIAlignment(GV->getAddressSpace());
3062   if (const llvm::MaybeAlign Alignment = GV->getAlign())
3063     EXPECT_EQ(*Alignment, std::max(TypeAlign, PtrAlign));
3064 }
3065 
3066 TEST_F(OpenMPIRBuilderTest, OrderedDirectiveDependSource) {
3067   using InsertPointTy = OpenMPIRBuilder::InsertPointTy;
3068   OpenMPIRBuilder OMPBuilder(*M);
3069   OMPBuilder.initialize();
3070   F->setName("func");
3071   IRBuilder<> Builder(BB);
3072   LLVMContext &Ctx = M->getContext();
3073 
3074   OpenMPIRBuilder::LocationDescription Loc({Builder.saveIP(), DL});
3075 
3076   InsertPointTy AllocaIP(&F->getEntryBlock(),
3077                          F->getEntryBlock().getFirstInsertionPt());
3078 
3079   unsigned NumLoops = 2;
3080   SmallVector<Value *, 2> StoreValues;
3081   Type *LCTy = Type::getInt64Ty(Ctx);
3082   StoreValues.emplace_back(ConstantInt::get(LCTy, 1));
3083   StoreValues.emplace_back(ConstantInt::get(LCTy, 2));
3084 
3085   // Test for "#omp ordered depend(source)"
3086   Builder.restoreIP(OMPBuilder.createOrderedDepend(Builder, AllocaIP, NumLoops,
3087                                                    StoreValues, ".cnt.addr",
3088                                                    /*IsDependSource=*/true));
3089 
3090   Builder.CreateRetVoid();
3091   OMPBuilder.finalize();
3092   EXPECT_FALSE(verifyModule(*M, &errs()));
3093 
3094   AllocaInst *AllocInst = dyn_cast<AllocaInst>(&BB->front());
3095   ASSERT_NE(AllocInst, nullptr);
3096   ArrayType *ArrType = dyn_cast<ArrayType>(AllocInst->getAllocatedType());
3097   EXPECT_EQ(ArrType->getNumElements(), NumLoops);
3098   EXPECT_TRUE(
3099       AllocInst->getAllocatedType()->getArrayElementType()->isIntegerTy(64));
3100 
3101   Instruction *IterInst = dyn_cast<Instruction>(AllocInst);
3102   for (unsigned Iter = 0; Iter < NumLoops; Iter++) {
3103     GetElementPtrInst *DependAddrGEPIter =
3104         dyn_cast<GetElementPtrInst>(IterInst->getNextNode());
3105     ASSERT_NE(DependAddrGEPIter, nullptr);
3106     EXPECT_EQ(DependAddrGEPIter->getPointerOperand(), AllocInst);
3107     EXPECT_EQ(DependAddrGEPIter->getNumIndices(), (unsigned)2);
3108     auto *FirstIdx = dyn_cast<ConstantInt>(DependAddrGEPIter->getOperand(1));
3109     auto *SecondIdx = dyn_cast<ConstantInt>(DependAddrGEPIter->getOperand(2));
3110     ASSERT_NE(FirstIdx, nullptr);
3111     ASSERT_NE(SecondIdx, nullptr);
3112     EXPECT_EQ(FirstIdx->getValue(), 0);
3113     EXPECT_EQ(SecondIdx->getValue(), Iter);
3114     StoreInst *StoreValue =
3115         dyn_cast<StoreInst>(DependAddrGEPIter->getNextNode());
3116     ASSERT_NE(StoreValue, nullptr);
3117     EXPECT_EQ(StoreValue->getValueOperand(), StoreValues[Iter]);
3118     EXPECT_EQ(StoreValue->getPointerOperand(), DependAddrGEPIter);
3119     EXPECT_EQ(StoreValue->getAlign(), Align(8));
3120     IterInst = dyn_cast<Instruction>(StoreValue);
3121   }
3122 
3123   GetElementPtrInst *DependBaseAddrGEP =
3124       dyn_cast<GetElementPtrInst>(IterInst->getNextNode());
3125   ASSERT_NE(DependBaseAddrGEP, nullptr);
3126   EXPECT_EQ(DependBaseAddrGEP->getPointerOperand(), AllocInst);
3127   EXPECT_EQ(DependBaseAddrGEP->getNumIndices(), (unsigned)2);
3128   auto *FirstIdx = dyn_cast<ConstantInt>(DependBaseAddrGEP->getOperand(1));
3129   auto *SecondIdx = dyn_cast<ConstantInt>(DependBaseAddrGEP->getOperand(2));
3130   ASSERT_NE(FirstIdx, nullptr);
3131   ASSERT_NE(SecondIdx, nullptr);
3132   EXPECT_EQ(FirstIdx->getValue(), 0);
3133   EXPECT_EQ(SecondIdx->getValue(), 0);
3134 
3135   CallInst *GTID = dyn_cast<CallInst>(DependBaseAddrGEP->getNextNode());
3136   ASSERT_NE(GTID, nullptr);
3137   EXPECT_EQ(GTID->arg_size(), 1U);
3138   EXPECT_EQ(GTID->getCalledFunction()->getName(), "__kmpc_global_thread_num");
3139   EXPECT_FALSE(GTID->getCalledFunction()->doesNotAccessMemory());
3140   EXPECT_FALSE(GTID->getCalledFunction()->doesNotFreeMemory());
3141 
3142   CallInst *Depend = dyn_cast<CallInst>(GTID->getNextNode());
3143   ASSERT_NE(Depend, nullptr);
3144   EXPECT_EQ(Depend->arg_size(), 3U);
3145   EXPECT_EQ(Depend->getCalledFunction()->getName(), "__kmpc_doacross_post");
3146   EXPECT_TRUE(isa<GlobalVariable>(Depend->getArgOperand(0)));
3147   EXPECT_EQ(Depend->getArgOperand(1), GTID);
3148   EXPECT_EQ(Depend->getArgOperand(2), DependBaseAddrGEP);
3149 }
3150 
3151 TEST_F(OpenMPIRBuilderTest, OrderedDirectiveDependSink) {
3152   using InsertPointTy = OpenMPIRBuilder::InsertPointTy;
3153   OpenMPIRBuilder OMPBuilder(*M);
3154   OMPBuilder.initialize();
3155   F->setName("func");
3156   IRBuilder<> Builder(BB);
3157   LLVMContext &Ctx = M->getContext();
3158 
3159   OpenMPIRBuilder::LocationDescription Loc({Builder.saveIP(), DL});
3160 
3161   InsertPointTy AllocaIP(&F->getEntryBlock(),
3162                          F->getEntryBlock().getFirstInsertionPt());
3163 
3164   unsigned NumLoops = 2;
3165   SmallVector<Value *, 2> StoreValues;
3166   Type *LCTy = Type::getInt64Ty(Ctx);
3167   StoreValues.emplace_back(ConstantInt::get(LCTy, 1));
3168   StoreValues.emplace_back(ConstantInt::get(LCTy, 2));
3169 
3170   // Test for "#omp ordered depend(sink: vec)"
3171   Builder.restoreIP(OMPBuilder.createOrderedDepend(Builder, AllocaIP, NumLoops,
3172                                                    StoreValues, ".cnt.addr",
3173                                                    /*IsDependSource=*/false));
3174 
3175   Builder.CreateRetVoid();
3176   OMPBuilder.finalize();
3177   EXPECT_FALSE(verifyModule(*M, &errs()));
3178 
3179   AllocaInst *AllocInst = dyn_cast<AllocaInst>(&BB->front());
3180   ASSERT_NE(AllocInst, nullptr);
3181   ArrayType *ArrType = dyn_cast<ArrayType>(AllocInst->getAllocatedType());
3182   EXPECT_EQ(ArrType->getNumElements(), NumLoops);
3183   EXPECT_TRUE(
3184       AllocInst->getAllocatedType()->getArrayElementType()->isIntegerTy(64));
3185 
3186   Instruction *IterInst = dyn_cast<Instruction>(AllocInst);
3187   for (unsigned Iter = 0; Iter < NumLoops; Iter++) {
3188     GetElementPtrInst *DependAddrGEPIter =
3189         dyn_cast<GetElementPtrInst>(IterInst->getNextNode());
3190     ASSERT_NE(DependAddrGEPIter, nullptr);
3191     EXPECT_EQ(DependAddrGEPIter->getPointerOperand(), AllocInst);
3192     EXPECT_EQ(DependAddrGEPIter->getNumIndices(), (unsigned)2);
3193     auto *FirstIdx = dyn_cast<ConstantInt>(DependAddrGEPIter->getOperand(1));
3194     auto *SecondIdx = dyn_cast<ConstantInt>(DependAddrGEPIter->getOperand(2));
3195     ASSERT_NE(FirstIdx, nullptr);
3196     ASSERT_NE(SecondIdx, nullptr);
3197     EXPECT_EQ(FirstIdx->getValue(), 0);
3198     EXPECT_EQ(SecondIdx->getValue(), Iter);
3199     StoreInst *StoreValue =
3200         dyn_cast<StoreInst>(DependAddrGEPIter->getNextNode());
3201     ASSERT_NE(StoreValue, nullptr);
3202     EXPECT_EQ(StoreValue->getValueOperand(), StoreValues[Iter]);
3203     EXPECT_EQ(StoreValue->getPointerOperand(), DependAddrGEPIter);
3204     EXPECT_EQ(StoreValue->getAlign(), Align(8));
3205     IterInst = dyn_cast<Instruction>(StoreValue);
3206   }
3207 
3208   GetElementPtrInst *DependBaseAddrGEP =
3209       dyn_cast<GetElementPtrInst>(IterInst->getNextNode());
3210   ASSERT_NE(DependBaseAddrGEP, nullptr);
3211   EXPECT_EQ(DependBaseAddrGEP->getPointerOperand(), AllocInst);
3212   EXPECT_EQ(DependBaseAddrGEP->getNumIndices(), (unsigned)2);
3213   auto *FirstIdx = dyn_cast<ConstantInt>(DependBaseAddrGEP->getOperand(1));
3214   auto *SecondIdx = dyn_cast<ConstantInt>(DependBaseAddrGEP->getOperand(2));
3215   ASSERT_NE(FirstIdx, nullptr);
3216   ASSERT_NE(SecondIdx, nullptr);
3217   EXPECT_EQ(FirstIdx->getValue(), 0);
3218   EXPECT_EQ(SecondIdx->getValue(), 0);
3219 
3220   CallInst *GTID = dyn_cast<CallInst>(DependBaseAddrGEP->getNextNode());
3221   ASSERT_NE(GTID, nullptr);
3222   EXPECT_EQ(GTID->arg_size(), 1U);
3223   EXPECT_EQ(GTID->getCalledFunction()->getName(), "__kmpc_global_thread_num");
3224   EXPECT_FALSE(GTID->getCalledFunction()->doesNotAccessMemory());
3225   EXPECT_FALSE(GTID->getCalledFunction()->doesNotFreeMemory());
3226 
3227   CallInst *Depend = dyn_cast<CallInst>(GTID->getNextNode());
3228   ASSERT_NE(Depend, nullptr);
3229   EXPECT_EQ(Depend->arg_size(), 3U);
3230   EXPECT_EQ(Depend->getCalledFunction()->getName(), "__kmpc_doacross_wait");
3231   EXPECT_TRUE(isa<GlobalVariable>(Depend->getArgOperand(0)));
3232   EXPECT_EQ(Depend->getArgOperand(1), GTID);
3233   EXPECT_EQ(Depend->getArgOperand(2), DependBaseAddrGEP);
3234 }
3235 
3236 TEST_F(OpenMPIRBuilderTest, OrderedDirectiveThreads) {
3237   using InsertPointTy = OpenMPIRBuilder::InsertPointTy;
3238   OpenMPIRBuilder OMPBuilder(*M);
3239   OMPBuilder.initialize();
3240   F->setName("func");
3241   IRBuilder<> Builder(BB);
3242 
3243   OpenMPIRBuilder::LocationDescription Loc({Builder.saveIP(), DL});
3244 
3245   AllocaInst *PrivAI =
3246       Builder.CreateAlloca(F->arg_begin()->getType(), nullptr, "priv.inst");
3247 
3248   auto BodyGenCB = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP) {
3249     llvm::BasicBlock *CodeGenIPBB = CodeGenIP.getBlock();
3250     llvm::Instruction *CodeGenIPInst = &*CodeGenIP.getPoint();
3251     EXPECT_EQ(CodeGenIPBB->getTerminator(), CodeGenIPInst);
3252 
3253     Builder.restoreIP(CodeGenIP);
3254     Builder.CreateStore(F->arg_begin(), PrivAI);
3255     Value *PrivLoad =
3256         Builder.CreateLoad(PrivAI->getAllocatedType(), PrivAI, "local.use");
3257     Builder.CreateICmpNE(F->arg_begin(), PrivLoad);
3258   };
3259 
3260   auto FiniCB = [&](InsertPointTy IP) {
3261     BasicBlock *IPBB = IP.getBlock();
3262     EXPECT_NE(IPBB->end(), IP.getPoint());
3263   };
3264 
3265   // Test for "#omp ordered [threads]"
3266   BasicBlock *EntryBB = Builder.GetInsertBlock();
3267   OpenMPIRBuilder::InsertPointOrErrorTy AfterIP =
3268       OMPBuilder.createOrderedThreadsSimd(Builder, BODYGENCB_WRAPPER(BodyGenCB),
3269                                           FINICB_WRAPPER(FiniCB), true);
3270   assert(AfterIP && "unexpected error");
3271   Builder.restoreIP(*AfterIP);
3272 
3273   Builder.CreateRetVoid();
3274   OMPBuilder.finalize();
3275   EXPECT_FALSE(verifyModule(*M, &errs()));
3276 
3277   EXPECT_NE(EntryBB->getTerminator(), nullptr);
3278 
3279   CallInst *OrderedEntryCI = nullptr;
3280   for (auto &EI : *EntryBB) {
3281     Instruction *Cur = &EI;
3282     if (isa<CallInst>(Cur)) {
3283       OrderedEntryCI = cast<CallInst>(Cur);
3284       if (OrderedEntryCI->getCalledFunction()->getName() == "__kmpc_ordered")
3285         break;
3286       OrderedEntryCI = nullptr;
3287     }
3288   }
3289   EXPECT_NE(OrderedEntryCI, nullptr);
3290   EXPECT_EQ(OrderedEntryCI->arg_size(), 2U);
3291   EXPECT_EQ(OrderedEntryCI->getCalledFunction()->getName(), "__kmpc_ordered");
3292   EXPECT_TRUE(isa<GlobalVariable>(OrderedEntryCI->getArgOperand(0)));
3293 
3294   CallInst *OrderedEndCI = nullptr;
3295   for (auto &FI : *EntryBB) {
3296     Instruction *Cur = &FI;
3297     if (isa<CallInst>(Cur)) {
3298       OrderedEndCI = cast<CallInst>(Cur);
3299       if (OrderedEndCI->getCalledFunction()->getName() == "__kmpc_end_ordered")
3300         break;
3301       OrderedEndCI = nullptr;
3302     }
3303   }
3304   EXPECT_NE(OrderedEndCI, nullptr);
3305   EXPECT_EQ(OrderedEndCI->arg_size(), 2U);
3306   EXPECT_TRUE(isa<GlobalVariable>(OrderedEndCI->getArgOperand(0)));
3307   EXPECT_EQ(OrderedEndCI->getArgOperand(1), OrderedEntryCI->getArgOperand(1));
3308 }
3309 
3310 TEST_F(OpenMPIRBuilderTest, OrderedDirectiveSimd) {
3311   using InsertPointTy = OpenMPIRBuilder::InsertPointTy;
3312   OpenMPIRBuilder OMPBuilder(*M);
3313   OMPBuilder.initialize();
3314   F->setName("func");
3315   IRBuilder<> Builder(BB);
3316 
3317   OpenMPIRBuilder::LocationDescription Loc({Builder.saveIP(), DL});
3318 
3319   AllocaInst *PrivAI =
3320       Builder.CreateAlloca(F->arg_begin()->getType(), nullptr, "priv.inst");
3321 
3322   auto BodyGenCB = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP) {
3323     llvm::BasicBlock *CodeGenIPBB = CodeGenIP.getBlock();
3324     llvm::Instruction *CodeGenIPInst = &*CodeGenIP.getPoint();
3325     EXPECT_EQ(CodeGenIPBB->getTerminator(), CodeGenIPInst);
3326 
3327     Builder.restoreIP(CodeGenIP);
3328     Builder.CreateStore(F->arg_begin(), PrivAI);
3329     Value *PrivLoad =
3330         Builder.CreateLoad(PrivAI->getAllocatedType(), PrivAI, "local.use");
3331     Builder.CreateICmpNE(F->arg_begin(), PrivLoad);
3332   };
3333 
3334   auto FiniCB = [&](InsertPointTy IP) {
3335     BasicBlock *IPBB = IP.getBlock();
3336     EXPECT_NE(IPBB->end(), IP.getPoint());
3337   };
3338 
3339   // Test for "#omp ordered simd"
3340   BasicBlock *EntryBB = Builder.GetInsertBlock();
3341   OpenMPIRBuilder::InsertPointOrErrorTy AfterIP =
3342       OMPBuilder.createOrderedThreadsSimd(Builder, BODYGENCB_WRAPPER(BodyGenCB),
3343                                           FINICB_WRAPPER(FiniCB), false);
3344   assert(AfterIP && "unexpected error");
3345   Builder.restoreIP(*AfterIP);
3346 
3347   Builder.CreateRetVoid();
3348   OMPBuilder.finalize();
3349   EXPECT_FALSE(verifyModule(*M, &errs()));
3350 
3351   EXPECT_NE(EntryBB->getTerminator(), nullptr);
3352 
3353   CallInst *OrderedEntryCI = nullptr;
3354   for (auto &EI : *EntryBB) {
3355     Instruction *Cur = &EI;
3356     if (isa<CallInst>(Cur)) {
3357       OrderedEntryCI = cast<CallInst>(Cur);
3358       if (OrderedEntryCI->getCalledFunction()->getName() == "__kmpc_ordered")
3359         break;
3360       OrderedEntryCI = nullptr;
3361     }
3362   }
3363   EXPECT_EQ(OrderedEntryCI, nullptr);
3364 
3365   CallInst *OrderedEndCI = nullptr;
3366   for (auto &FI : *EntryBB) {
3367     Instruction *Cur = &FI;
3368     if (isa<CallInst>(Cur)) {
3369       OrderedEndCI = cast<CallInst>(Cur);
3370       if (OrderedEndCI->getCalledFunction()->getName() == "__kmpc_end_ordered")
3371         break;
3372       OrderedEndCI = nullptr;
3373     }
3374   }
3375   EXPECT_EQ(OrderedEndCI, nullptr);
3376 }
3377 
3378 TEST_F(OpenMPIRBuilderTest, CopyinBlocks) {
3379   OpenMPIRBuilder OMPBuilder(*M);
3380   OMPBuilder.initialize();
3381   F->setName("func");
3382   IRBuilder<> Builder(BB);
3383 
3384   OpenMPIRBuilder::LocationDescription Loc({Builder.saveIP(), DL});
3385 
3386   IntegerType *Int32 = Type::getInt32Ty(M->getContext());
3387   AllocaInst *MasterAddress = Builder.CreateAlloca(Builder.getPtrTy());
3388   AllocaInst *PrivAddress = Builder.CreateAlloca(Builder.getPtrTy());
3389 
3390   BasicBlock *EntryBB = BB;
3391 
3392   OMPBuilder.createCopyinClauseBlocks(Builder.saveIP(), MasterAddress,
3393                                       PrivAddress, Int32, /*BranchtoEnd*/ true);
3394 
3395   BranchInst *EntryBr = dyn_cast_or_null<BranchInst>(EntryBB->getTerminator());
3396 
3397   EXPECT_NE(EntryBr, nullptr);
3398   EXPECT_TRUE(EntryBr->isConditional());
3399 
3400   BasicBlock *NotMasterBB = EntryBr->getSuccessor(0);
3401   BasicBlock *CopyinEnd = EntryBr->getSuccessor(1);
3402   CmpInst *CMP = dyn_cast_or_null<CmpInst>(EntryBr->getCondition());
3403 
3404   EXPECT_NE(CMP, nullptr);
3405   EXPECT_NE(NotMasterBB, nullptr);
3406   EXPECT_NE(CopyinEnd, nullptr);
3407 
3408   BranchInst *NotMasterBr =
3409       dyn_cast_or_null<BranchInst>(NotMasterBB->getTerminator());
3410   EXPECT_NE(NotMasterBr, nullptr);
3411   EXPECT_FALSE(NotMasterBr->isConditional());
3412   EXPECT_EQ(CopyinEnd, NotMasterBr->getSuccessor(0));
3413 }
3414 
3415 TEST_F(OpenMPIRBuilderTest, SingleDirective) {
3416   using InsertPointTy = OpenMPIRBuilder::InsertPointTy;
3417   OpenMPIRBuilder OMPBuilder(*M);
3418   OMPBuilder.initialize();
3419   F->setName("func");
3420   IRBuilder<> Builder(BB);
3421 
3422   OpenMPIRBuilder::LocationDescription Loc({Builder.saveIP(), DL});
3423 
3424   AllocaInst *PrivAI = nullptr;
3425 
3426   BasicBlock *EntryBB = nullptr;
3427   BasicBlock *ThenBB = nullptr;
3428 
3429   auto BodyGenCB = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP) {
3430     if (AllocaIP.isSet())
3431       Builder.restoreIP(AllocaIP);
3432     else
3433       Builder.SetInsertPoint(&*(F->getEntryBlock().getFirstInsertionPt()));
3434     PrivAI = Builder.CreateAlloca(F->arg_begin()->getType());
3435     Builder.CreateStore(F->arg_begin(), PrivAI);
3436 
3437     llvm::BasicBlock *CodeGenIPBB = CodeGenIP.getBlock();
3438     llvm::Instruction *CodeGenIPInst = &*CodeGenIP.getPoint();
3439     EXPECT_EQ(CodeGenIPBB->getTerminator(), CodeGenIPInst);
3440 
3441     Builder.restoreIP(CodeGenIP);
3442 
3443     // collect some info for checks later
3444     ThenBB = Builder.GetInsertBlock();
3445     EntryBB = ThenBB->getUniquePredecessor();
3446 
3447     // simple instructions for body
3448     Value *PrivLoad =
3449         Builder.CreateLoad(PrivAI->getAllocatedType(), PrivAI, "local.use");
3450     Builder.CreateICmpNE(F->arg_begin(), PrivLoad);
3451   };
3452 
3453   auto FiniCB = [&](InsertPointTy IP) {
3454     BasicBlock *IPBB = IP.getBlock();
3455     EXPECT_NE(IPBB->end(), IP.getPoint());
3456   };
3457 
3458   OpenMPIRBuilder::InsertPointOrErrorTy AfterIP =
3459       OMPBuilder.createSingle(Builder, BODYGENCB_WRAPPER(BodyGenCB),
3460                               FINICB_WRAPPER(FiniCB), /*IsNowait*/ false);
3461   assert(AfterIP && "unexpected error");
3462   Builder.restoreIP(*AfterIP);
3463   Value *EntryBBTI = EntryBB->getTerminator();
3464   EXPECT_NE(EntryBBTI, nullptr);
3465   EXPECT_TRUE(isa<BranchInst>(EntryBBTI));
3466   BranchInst *EntryBr = cast<BranchInst>(EntryBB->getTerminator());
3467   EXPECT_TRUE(EntryBr->isConditional());
3468   EXPECT_EQ(EntryBr->getSuccessor(0), ThenBB);
3469   BasicBlock *ExitBB = ThenBB->getUniqueSuccessor();
3470   EXPECT_EQ(EntryBr->getSuccessor(1), ExitBB);
3471 
3472   CmpInst *CondInst = cast<CmpInst>(EntryBr->getCondition());
3473   EXPECT_TRUE(isa<CallInst>(CondInst->getOperand(0)));
3474 
3475   CallInst *SingleEntryCI = cast<CallInst>(CondInst->getOperand(0));
3476   EXPECT_EQ(SingleEntryCI->arg_size(), 2U);
3477   EXPECT_EQ(SingleEntryCI->getCalledFunction()->getName(), "__kmpc_single");
3478   EXPECT_TRUE(isa<GlobalVariable>(SingleEntryCI->getArgOperand(0)));
3479 
3480   CallInst *SingleEndCI = nullptr;
3481   for (auto &FI : *ThenBB) {
3482     Instruction *cur = &FI;
3483     if (isa<CallInst>(cur)) {
3484       SingleEndCI = cast<CallInst>(cur);
3485       if (SingleEndCI->getCalledFunction()->getName() == "__kmpc_end_single")
3486         break;
3487       SingleEndCI = nullptr;
3488     }
3489   }
3490   EXPECT_NE(SingleEndCI, nullptr);
3491   EXPECT_EQ(SingleEndCI->arg_size(), 2U);
3492   EXPECT_TRUE(isa<GlobalVariable>(SingleEndCI->getArgOperand(0)));
3493   EXPECT_EQ(SingleEndCI->getArgOperand(1), SingleEntryCI->getArgOperand(1));
3494 
3495   bool FoundBarrier = false;
3496   for (auto &FI : *ExitBB) {
3497     Instruction *cur = &FI;
3498     if (auto CI = dyn_cast<CallInst>(cur)) {
3499       if (CI->getCalledFunction()->getName() == "__kmpc_barrier") {
3500         FoundBarrier = true;
3501         break;
3502       }
3503     }
3504   }
3505   EXPECT_TRUE(FoundBarrier);
3506 }
3507 
3508 TEST_F(OpenMPIRBuilderTest, SingleDirectiveNowait) {
3509   using InsertPointTy = OpenMPIRBuilder::InsertPointTy;
3510   OpenMPIRBuilder OMPBuilder(*M);
3511   OMPBuilder.initialize();
3512   F->setName("func");
3513   IRBuilder<> Builder(BB);
3514 
3515   OpenMPIRBuilder::LocationDescription Loc({Builder.saveIP(), DL});
3516 
3517   AllocaInst *PrivAI = nullptr;
3518 
3519   BasicBlock *EntryBB = nullptr;
3520   BasicBlock *ThenBB = nullptr;
3521 
3522   auto BodyGenCB = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP) {
3523     if (AllocaIP.isSet())
3524       Builder.restoreIP(AllocaIP);
3525     else
3526       Builder.SetInsertPoint(&*(F->getEntryBlock().getFirstInsertionPt()));
3527     PrivAI = Builder.CreateAlloca(F->arg_begin()->getType());
3528     Builder.CreateStore(F->arg_begin(), PrivAI);
3529 
3530     llvm::BasicBlock *CodeGenIPBB = CodeGenIP.getBlock();
3531     llvm::Instruction *CodeGenIPInst = &*CodeGenIP.getPoint();
3532     EXPECT_EQ(CodeGenIPBB->getTerminator(), CodeGenIPInst);
3533 
3534     Builder.restoreIP(CodeGenIP);
3535 
3536     // collect some info for checks later
3537     ThenBB = Builder.GetInsertBlock();
3538     EntryBB = ThenBB->getUniquePredecessor();
3539 
3540     // simple instructions for body
3541     Value *PrivLoad =
3542         Builder.CreateLoad(PrivAI->getAllocatedType(), PrivAI, "local.use");
3543     Builder.CreateICmpNE(F->arg_begin(), PrivLoad);
3544   };
3545 
3546   auto FiniCB = [&](InsertPointTy IP) {
3547     BasicBlock *IPBB = IP.getBlock();
3548     EXPECT_NE(IPBB->end(), IP.getPoint());
3549   };
3550 
3551   OpenMPIRBuilder::InsertPointOrErrorTy AfterIP =
3552       OMPBuilder.createSingle(Builder, BODYGENCB_WRAPPER(BodyGenCB),
3553                               FINICB_WRAPPER(FiniCB), /*IsNowait*/ true);
3554   assert(AfterIP && "unexpected error");
3555   Builder.restoreIP(*AfterIP);
3556   Value *EntryBBTI = EntryBB->getTerminator();
3557   EXPECT_NE(EntryBBTI, nullptr);
3558   EXPECT_TRUE(isa<BranchInst>(EntryBBTI));
3559   BranchInst *EntryBr = cast<BranchInst>(EntryBB->getTerminator());
3560   EXPECT_TRUE(EntryBr->isConditional());
3561   EXPECT_EQ(EntryBr->getSuccessor(0), ThenBB);
3562   BasicBlock *ExitBB = ThenBB->getUniqueSuccessor();
3563   EXPECT_EQ(EntryBr->getSuccessor(1), ExitBB);
3564 
3565   CmpInst *CondInst = cast<CmpInst>(EntryBr->getCondition());
3566   EXPECT_TRUE(isa<CallInst>(CondInst->getOperand(0)));
3567 
3568   CallInst *SingleEntryCI = cast<CallInst>(CondInst->getOperand(0));
3569   EXPECT_EQ(SingleEntryCI->arg_size(), 2U);
3570   EXPECT_EQ(SingleEntryCI->getCalledFunction()->getName(), "__kmpc_single");
3571   EXPECT_TRUE(isa<GlobalVariable>(SingleEntryCI->getArgOperand(0)));
3572 
3573   CallInst *SingleEndCI = nullptr;
3574   for (auto &FI : *ThenBB) {
3575     Instruction *cur = &FI;
3576     if (isa<CallInst>(cur)) {
3577       SingleEndCI = cast<CallInst>(cur);
3578       if (SingleEndCI->getCalledFunction()->getName() == "__kmpc_end_single")
3579         break;
3580       SingleEndCI = nullptr;
3581     }
3582   }
3583   EXPECT_NE(SingleEndCI, nullptr);
3584   EXPECT_EQ(SingleEndCI->arg_size(), 2U);
3585   EXPECT_TRUE(isa<GlobalVariable>(SingleEndCI->getArgOperand(0)));
3586   EXPECT_EQ(SingleEndCI->getArgOperand(1), SingleEntryCI->getArgOperand(1));
3587 
3588   CallInst *ExitBarrier = nullptr;
3589   for (auto &FI : *ExitBB) {
3590     Instruction *cur = &FI;
3591     if (auto CI = dyn_cast<CallInst>(cur)) {
3592       if (CI->getCalledFunction()->getName() == "__kmpc_barrier") {
3593         ExitBarrier = CI;
3594         break;
3595       }
3596     }
3597   }
3598   EXPECT_EQ(ExitBarrier, nullptr);
3599 }
3600 
3601 // Helper class to check each instruction of a BB.
3602 class BBInstIter {
3603   BasicBlock *BB;
3604   BasicBlock::iterator BBI;
3605 
3606 public:
3607   BBInstIter(BasicBlock *BB) : BB(BB), BBI(BB->begin()) {}
3608 
3609   bool hasNext() const { return BBI != BB->end(); }
3610 
3611   template <typename InstTy> InstTy *next() {
3612     if (!hasNext())
3613       return nullptr;
3614     Instruction *Cur = &*BBI++;
3615     if (!isa<InstTy>(Cur))
3616       return nullptr;
3617     return cast<InstTy>(Cur);
3618   }
3619 };
3620 
3621 TEST_F(OpenMPIRBuilderTest, SingleDirectiveCopyPrivate) {
3622   using InsertPointTy = OpenMPIRBuilder::InsertPointTy;
3623   OpenMPIRBuilder OMPBuilder(*M);
3624   OMPBuilder.initialize();
3625   F->setName("func");
3626   IRBuilder<> Builder(BB);
3627 
3628   OpenMPIRBuilder::LocationDescription Loc({Builder.saveIP(), DL});
3629 
3630   AllocaInst *PrivAI = nullptr;
3631 
3632   BasicBlock *EntryBB = nullptr;
3633   BasicBlock *ThenBB = nullptr;
3634 
3635   Value *CPVar = Builder.CreateAlloca(F->arg_begin()->getType());
3636   Builder.CreateStore(F->arg_begin(), CPVar);
3637 
3638   FunctionType *CopyFuncTy = FunctionType::get(
3639       Builder.getVoidTy(), {Builder.getPtrTy(), Builder.getPtrTy()}, false);
3640   Function *CopyFunc =
3641       Function::Create(CopyFuncTy, Function::PrivateLinkage, "copy_var", *M);
3642 
3643   auto BodyGenCB = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP) {
3644     if (AllocaIP.isSet())
3645       Builder.restoreIP(AllocaIP);
3646     else
3647       Builder.SetInsertPoint(&*(F->getEntryBlock().getFirstInsertionPt()));
3648     PrivAI = Builder.CreateAlloca(F->arg_begin()->getType());
3649     Builder.CreateStore(F->arg_begin(), PrivAI);
3650 
3651     llvm::BasicBlock *CodeGenIPBB = CodeGenIP.getBlock();
3652     llvm::Instruction *CodeGenIPInst = &*CodeGenIP.getPoint();
3653     EXPECT_EQ(CodeGenIPBB->getTerminator(), CodeGenIPInst);
3654 
3655     Builder.restoreIP(CodeGenIP);
3656 
3657     // collect some info for checks later
3658     ThenBB = Builder.GetInsertBlock();
3659     EntryBB = ThenBB->getUniquePredecessor();
3660 
3661     // simple instructions for body
3662     Value *PrivLoad =
3663         Builder.CreateLoad(PrivAI->getAllocatedType(), PrivAI, "local.use");
3664     Builder.CreateICmpNE(F->arg_begin(), PrivLoad);
3665   };
3666 
3667   auto FiniCB = [&](InsertPointTy IP) {
3668     BasicBlock *IPBB = IP.getBlock();
3669     // IP must be before the unconditional branch to ExitBB
3670     EXPECT_NE(IPBB->end(), IP.getPoint());
3671   };
3672 
3673   OpenMPIRBuilder::InsertPointOrErrorTy AfterIP = OMPBuilder.createSingle(
3674       Builder, BODYGENCB_WRAPPER(BodyGenCB), FINICB_WRAPPER(FiniCB),
3675       /*IsNowait*/ false, {CPVar}, {CopyFunc});
3676   assert(AfterIP && "unexpected error");
3677   Builder.restoreIP(*AfterIP);
3678   Value *EntryBBTI = EntryBB->getTerminator();
3679   EXPECT_NE(EntryBBTI, nullptr);
3680   EXPECT_TRUE(isa<BranchInst>(EntryBBTI));
3681   BranchInst *EntryBr = cast<BranchInst>(EntryBB->getTerminator());
3682   EXPECT_TRUE(EntryBr->isConditional());
3683   EXPECT_EQ(EntryBr->getSuccessor(0), ThenBB);
3684   BasicBlock *ExitBB = ThenBB->getUniqueSuccessor();
3685   EXPECT_EQ(EntryBr->getSuccessor(1), ExitBB);
3686 
3687   CmpInst *CondInst = cast<CmpInst>(EntryBr->getCondition());
3688   EXPECT_TRUE(isa<CallInst>(CondInst->getOperand(0)));
3689 
3690   CallInst *SingleEntryCI = cast<CallInst>(CondInst->getOperand(0));
3691   EXPECT_EQ(SingleEntryCI->arg_size(), 2U);
3692   EXPECT_EQ(SingleEntryCI->getCalledFunction()->getName(), "__kmpc_single");
3693   EXPECT_TRUE(isa<GlobalVariable>(SingleEntryCI->getArgOperand(0)));
3694 
3695   // check ThenBB
3696   BBInstIter ThenBBI(ThenBB);
3697   // load PrivAI
3698   auto *PrivLI = ThenBBI.next<LoadInst>();
3699   EXPECT_NE(PrivLI, nullptr);
3700   EXPECT_EQ(PrivLI->getPointerOperand(), PrivAI);
3701   // icmp
3702   EXPECT_TRUE(ThenBBI.next<ICmpInst>());
3703   // store 1, DidIt
3704   auto *DidItSI = ThenBBI.next<StoreInst>();
3705   EXPECT_NE(DidItSI, nullptr);
3706   EXPECT_EQ(DidItSI->getValueOperand(),
3707             ConstantInt::get(Type::getInt32Ty(Ctx), 1));
3708   Value *DidIt = DidItSI->getPointerOperand();
3709   // call __kmpc_end_single
3710   auto *SingleEndCI = ThenBBI.next<CallInst>();
3711   EXPECT_NE(SingleEndCI, nullptr);
3712   EXPECT_EQ(SingleEndCI->getCalledFunction()->getName(), "__kmpc_end_single");
3713   EXPECT_EQ(SingleEndCI->arg_size(), 2U);
3714   EXPECT_TRUE(isa<GlobalVariable>(SingleEndCI->getArgOperand(0)));
3715   EXPECT_EQ(SingleEndCI->getArgOperand(1), SingleEntryCI->getArgOperand(1));
3716   // br ExitBB
3717   auto *ExitBBBI = ThenBBI.next<BranchInst>();
3718   EXPECT_NE(ExitBBBI, nullptr);
3719   EXPECT_TRUE(ExitBBBI->isUnconditional());
3720   EXPECT_EQ(ExitBBBI->getOperand(0), ExitBB);
3721   EXPECT_FALSE(ThenBBI.hasNext());
3722 
3723   // check ExitBB
3724   BBInstIter ExitBBI(ExitBB);
3725   // call __kmpc_global_thread_num
3726   auto *ThreadNumCI = ExitBBI.next<CallInst>();
3727   EXPECT_NE(ThreadNumCI, nullptr);
3728   EXPECT_EQ(ThreadNumCI->getCalledFunction()->getName(),
3729             "__kmpc_global_thread_num");
3730   // load DidIt
3731   auto *DidItLI = ExitBBI.next<LoadInst>();
3732   EXPECT_NE(DidItLI, nullptr);
3733   EXPECT_EQ(DidItLI->getPointerOperand(), DidIt);
3734   // call __kmpc_copyprivate
3735   auto *CopyPrivateCI = ExitBBI.next<CallInst>();
3736   EXPECT_NE(CopyPrivateCI, nullptr);
3737   EXPECT_EQ(CopyPrivateCI->arg_size(), 6U);
3738   EXPECT_TRUE(isa<AllocaInst>(CopyPrivateCI->getArgOperand(3)));
3739   EXPECT_EQ(CopyPrivateCI->getArgOperand(3), CPVar);
3740   EXPECT_TRUE(isa<Function>(CopyPrivateCI->getArgOperand(4)));
3741   EXPECT_EQ(CopyPrivateCI->getArgOperand(4), CopyFunc);
3742   EXPECT_TRUE(isa<LoadInst>(CopyPrivateCI->getArgOperand(5)));
3743   DidItLI = cast<LoadInst>(CopyPrivateCI->getArgOperand(5));
3744   EXPECT_EQ(DidItLI->getOperand(0), DidIt);
3745   EXPECT_FALSE(ExitBBI.hasNext());
3746 }
3747 
3748 TEST_F(OpenMPIRBuilderTest, OMPAtomicReadFlt) {
3749   OpenMPIRBuilder OMPBuilder(*M);
3750   OMPBuilder.initialize();
3751   F->setName("func");
3752   IRBuilder<> Builder(BB);
3753 
3754   OpenMPIRBuilder::LocationDescription Loc({Builder.saveIP(), DL});
3755 
3756   Type *Float32 = Type::getFloatTy(M->getContext());
3757   AllocaInst *XVal = Builder.CreateAlloca(Float32);
3758   XVal->setName("AtomicVar");
3759   AllocaInst *VVal = Builder.CreateAlloca(Float32);
3760   VVal->setName("AtomicRead");
3761   AtomicOrdering AO = AtomicOrdering::Monotonic;
3762   OpenMPIRBuilder::AtomicOpValue X = {XVal, Float32, false, false};
3763   OpenMPIRBuilder::AtomicOpValue V = {VVal, Float32, false, false};
3764 
3765   Builder.restoreIP(OMPBuilder.createAtomicRead(Loc, X, V, AO));
3766 
3767   IntegerType *IntCastTy =
3768       IntegerType::get(M->getContext(), Float32->getScalarSizeInBits());
3769 
3770   LoadInst *AtomicLoad = cast<LoadInst>(VVal->getNextNode());
3771   EXPECT_TRUE(AtomicLoad->isAtomic());
3772   EXPECT_EQ(AtomicLoad->getPointerOperand(), XVal);
3773 
3774   BitCastInst *CastToFlt = cast<BitCastInst>(AtomicLoad->getNextNode());
3775   EXPECT_EQ(CastToFlt->getSrcTy(), IntCastTy);
3776   EXPECT_EQ(CastToFlt->getDestTy(), Float32);
3777   EXPECT_EQ(CastToFlt->getOperand(0), AtomicLoad);
3778 
3779   StoreInst *StoreofAtomic = cast<StoreInst>(CastToFlt->getNextNode());
3780   EXPECT_EQ(StoreofAtomic->getValueOperand(), CastToFlt);
3781   EXPECT_EQ(StoreofAtomic->getPointerOperand(), VVal);
3782 
3783   Builder.CreateRetVoid();
3784   OMPBuilder.finalize();
3785   EXPECT_FALSE(verifyModule(*M, &errs()));
3786 }
3787 
3788 TEST_F(OpenMPIRBuilderTest, OMPAtomicReadInt) {
3789   OpenMPIRBuilder OMPBuilder(*M);
3790   OMPBuilder.initialize();
3791   F->setName("func");
3792   IRBuilder<> Builder(BB);
3793 
3794   OpenMPIRBuilder::LocationDescription Loc({Builder.saveIP(), DL});
3795 
3796   IntegerType *Int32 = Type::getInt32Ty(M->getContext());
3797   AllocaInst *XVal = Builder.CreateAlloca(Int32);
3798   XVal->setName("AtomicVar");
3799   AllocaInst *VVal = Builder.CreateAlloca(Int32);
3800   VVal->setName("AtomicRead");
3801   AtomicOrdering AO = AtomicOrdering::Monotonic;
3802   OpenMPIRBuilder::AtomicOpValue X = {XVal, Int32, false, false};
3803   OpenMPIRBuilder::AtomicOpValue V = {VVal, Int32, false, false};
3804 
3805   BasicBlock *EntryBB = BB;
3806 
3807   Builder.restoreIP(OMPBuilder.createAtomicRead(Loc, X, V, AO));
3808   LoadInst *AtomicLoad = nullptr;
3809   StoreInst *StoreofAtomic = nullptr;
3810 
3811   for (Instruction &Cur : *EntryBB) {
3812     if (isa<LoadInst>(Cur)) {
3813       AtomicLoad = cast<LoadInst>(&Cur);
3814       if (AtomicLoad->getPointerOperand() == XVal)
3815         continue;
3816       AtomicLoad = nullptr;
3817     } else if (isa<StoreInst>(Cur)) {
3818       StoreofAtomic = cast<StoreInst>(&Cur);
3819       if (StoreofAtomic->getPointerOperand() == VVal)
3820         continue;
3821       StoreofAtomic = nullptr;
3822     }
3823   }
3824 
3825   EXPECT_NE(AtomicLoad, nullptr);
3826   EXPECT_TRUE(AtomicLoad->isAtomic());
3827 
3828   EXPECT_NE(StoreofAtomic, nullptr);
3829   EXPECT_EQ(StoreofAtomic->getValueOperand(), AtomicLoad);
3830 
3831   Builder.CreateRetVoid();
3832   OMPBuilder.finalize();
3833 
3834   EXPECT_FALSE(verifyModule(*M, &errs()));
3835 }
3836 
3837 TEST_F(OpenMPIRBuilderTest, OMPAtomicWriteFlt) {
3838   OpenMPIRBuilder OMPBuilder(*M);
3839   OMPBuilder.initialize();
3840   F->setName("func");
3841   IRBuilder<> Builder(BB);
3842 
3843   OpenMPIRBuilder::LocationDescription Loc({Builder.saveIP(), DL});
3844 
3845   LLVMContext &Ctx = M->getContext();
3846   Type *Float32 = Type::getFloatTy(Ctx);
3847   AllocaInst *XVal = Builder.CreateAlloca(Float32);
3848   XVal->setName("AtomicVar");
3849   OpenMPIRBuilder::AtomicOpValue X = {XVal, Float32, false, false};
3850   AtomicOrdering AO = AtomicOrdering::Monotonic;
3851   Constant *ValToWrite = ConstantFP::get(Float32, 1.0);
3852 
3853   Builder.restoreIP(OMPBuilder.createAtomicWrite(Loc, X, ValToWrite, AO));
3854 
3855   IntegerType *IntCastTy =
3856       IntegerType::get(M->getContext(), Float32->getScalarSizeInBits());
3857 
3858   Value *ExprCast = Builder.CreateBitCast(ValToWrite, IntCastTy);
3859 
3860   StoreInst *StoreofAtomic = cast<StoreInst>(XVal->getNextNode());
3861   EXPECT_EQ(StoreofAtomic->getValueOperand(), ExprCast);
3862   EXPECT_EQ(StoreofAtomic->getPointerOperand(), XVal);
3863   EXPECT_TRUE(StoreofAtomic->isAtomic());
3864 
3865   Builder.CreateRetVoid();
3866   OMPBuilder.finalize();
3867   EXPECT_FALSE(verifyModule(*M, &errs()));
3868 }
3869 
3870 TEST_F(OpenMPIRBuilderTest, OMPAtomicWriteInt) {
3871   OpenMPIRBuilder OMPBuilder(*M);
3872   OMPBuilder.initialize();
3873   F->setName("func");
3874   IRBuilder<> Builder(BB);
3875 
3876   OpenMPIRBuilder::LocationDescription Loc({Builder.saveIP(), DL});
3877 
3878   LLVMContext &Ctx = M->getContext();
3879   IntegerType *Int32 = Type::getInt32Ty(Ctx);
3880   AllocaInst *XVal = Builder.CreateAlloca(Int32);
3881   XVal->setName("AtomicVar");
3882   OpenMPIRBuilder::AtomicOpValue X = {XVal, Int32, false, false};
3883   AtomicOrdering AO = AtomicOrdering::Monotonic;
3884   ConstantInt *ValToWrite = ConstantInt::get(Type::getInt32Ty(Ctx), 1U);
3885 
3886   BasicBlock *EntryBB = BB;
3887 
3888   Builder.restoreIP(OMPBuilder.createAtomicWrite(Loc, X, ValToWrite, AO));
3889 
3890   StoreInst *StoreofAtomic = nullptr;
3891 
3892   for (Instruction &Cur : *EntryBB) {
3893     if (isa<StoreInst>(Cur)) {
3894       StoreofAtomic = cast<StoreInst>(&Cur);
3895       if (StoreofAtomic->getPointerOperand() == XVal)
3896         continue;
3897       StoreofAtomic = nullptr;
3898     }
3899   }
3900 
3901   EXPECT_NE(StoreofAtomic, nullptr);
3902   EXPECT_TRUE(StoreofAtomic->isAtomic());
3903   EXPECT_EQ(StoreofAtomic->getValueOperand(), ValToWrite);
3904 
3905   Builder.CreateRetVoid();
3906   OMPBuilder.finalize();
3907   EXPECT_FALSE(verifyModule(*M, &errs()));
3908 }
3909 
3910 TEST_F(OpenMPIRBuilderTest, OMPAtomicUpdate) {
3911   OpenMPIRBuilder OMPBuilder(*M);
3912   OMPBuilder.initialize();
3913   F->setName("func");
3914   IRBuilder<> Builder(BB);
3915 
3916   OpenMPIRBuilder::LocationDescription Loc({Builder.saveIP(), DL});
3917 
3918   IntegerType *Int32 = Type::getInt32Ty(M->getContext());
3919   AllocaInst *XVal = Builder.CreateAlloca(Int32);
3920   XVal->setName("AtomicVar");
3921   Builder.CreateStore(ConstantInt::get(Type::getInt32Ty(Ctx), 0U), XVal);
3922   OpenMPIRBuilder::AtomicOpValue X = {XVal, Int32, false, false};
3923   AtomicOrdering AO = AtomicOrdering::Monotonic;
3924   ConstantInt *ConstVal = ConstantInt::get(Type::getInt32Ty(Ctx), 1U);
3925   Value *Expr = nullptr;
3926   AtomicRMWInst::BinOp RMWOp = AtomicRMWInst::Sub;
3927   bool IsXLHSInRHSPart = false;
3928 
3929   BasicBlock *EntryBB = BB;
3930   OpenMPIRBuilder::InsertPointTy AllocaIP(EntryBB,
3931                                           EntryBB->getFirstInsertionPt());
3932   Value *Sub = nullptr;
3933 
3934   auto UpdateOp = [&](Value *Atomic, IRBuilder<> &IRB) {
3935     Sub = IRB.CreateSub(ConstVal, Atomic);
3936     return Sub;
3937   };
3938   OpenMPIRBuilder::InsertPointOrErrorTy AfterIP = OMPBuilder.createAtomicUpdate(
3939       Builder, AllocaIP, X, Expr, AO, RMWOp, UpdateOp, IsXLHSInRHSPart);
3940   assert(AfterIP && "unexpected error");
3941   Builder.restoreIP(*AfterIP);
3942   BasicBlock *ContBB = EntryBB->getSingleSuccessor();
3943   BranchInst *ContTI = dyn_cast<BranchInst>(ContBB->getTerminator());
3944   EXPECT_NE(ContTI, nullptr);
3945   BasicBlock *EndBB = ContTI->getSuccessor(0);
3946   EXPECT_TRUE(ContTI->isConditional());
3947   EXPECT_EQ(ContTI->getSuccessor(1), ContBB);
3948   EXPECT_NE(EndBB, nullptr);
3949 
3950   PHINode *Phi = dyn_cast<PHINode>(&ContBB->front());
3951   EXPECT_NE(Phi, nullptr);
3952   EXPECT_EQ(Phi->getNumIncomingValues(), 2U);
3953   EXPECT_EQ(Phi->getIncomingBlock(0), EntryBB);
3954   EXPECT_EQ(Phi->getIncomingBlock(1), ContBB);
3955 
3956   EXPECT_EQ(Sub->getNumUses(), 1U);
3957   StoreInst *St = dyn_cast<StoreInst>(Sub->user_back());
3958   AllocaInst *UpdateTemp = dyn_cast<AllocaInst>(St->getPointerOperand());
3959 
3960   ExtractValueInst *ExVI1 =
3961       dyn_cast<ExtractValueInst>(Phi->getIncomingValueForBlock(ContBB));
3962   EXPECT_NE(ExVI1, nullptr);
3963   AtomicCmpXchgInst *CmpExchg =
3964       dyn_cast<AtomicCmpXchgInst>(ExVI1->getAggregateOperand());
3965   EXPECT_NE(CmpExchg, nullptr);
3966   EXPECT_EQ(CmpExchg->getPointerOperand(), XVal);
3967   EXPECT_EQ(CmpExchg->getCompareOperand(), Phi);
3968   EXPECT_EQ(CmpExchg->getSuccessOrdering(), AtomicOrdering::Monotonic);
3969 
3970   LoadInst *Ld = dyn_cast<LoadInst>(CmpExchg->getNewValOperand());
3971   EXPECT_NE(Ld, nullptr);
3972   EXPECT_EQ(UpdateTemp, Ld->getPointerOperand());
3973 
3974   Builder.CreateRetVoid();
3975   OMPBuilder.finalize();
3976   EXPECT_FALSE(verifyModule(*M, &errs()));
3977 }
3978 
3979 TEST_F(OpenMPIRBuilderTest, OMPAtomicUpdateFloat) {
3980   OpenMPIRBuilder OMPBuilder(*M);
3981   OMPBuilder.initialize();
3982   F->setName("func");
3983   IRBuilder<> Builder(BB);
3984 
3985   OpenMPIRBuilder::LocationDescription Loc({Builder.saveIP(), DL});
3986 
3987   Type *FloatTy = Type::getFloatTy(M->getContext());
3988   AllocaInst *XVal = Builder.CreateAlloca(FloatTy);
3989   XVal->setName("AtomicVar");
3990   Builder.CreateStore(ConstantFP::get(Type::getFloatTy(Ctx), 0.0), XVal);
3991   OpenMPIRBuilder::AtomicOpValue X = {XVal, FloatTy, false, false};
3992   AtomicOrdering AO = AtomicOrdering::Monotonic;
3993   Constant *ConstVal = ConstantFP::get(Type::getFloatTy(Ctx), 1.0);
3994   Value *Expr = nullptr;
3995   AtomicRMWInst::BinOp RMWOp = AtomicRMWInst::FSub;
3996   bool IsXLHSInRHSPart = false;
3997 
3998   BasicBlock *EntryBB = BB;
3999   OpenMPIRBuilder::InsertPointTy AllocaIP(EntryBB,
4000                                           EntryBB->getFirstInsertionPt());
4001   Value *Sub = nullptr;
4002 
4003   auto UpdateOp = [&](Value *Atomic, IRBuilder<> &IRB) {
4004     Sub = IRB.CreateFSub(ConstVal, Atomic);
4005     return Sub;
4006   };
4007   OpenMPIRBuilder::InsertPointOrErrorTy AfterIP = OMPBuilder.createAtomicUpdate(
4008       Builder, AllocaIP, X, Expr, AO, RMWOp, UpdateOp, IsXLHSInRHSPart);
4009   assert(AfterIP && "unexpected error");
4010   Builder.restoreIP(*AfterIP);
4011   BasicBlock *ContBB = EntryBB->getSingleSuccessor();
4012   BranchInst *ContTI = dyn_cast<BranchInst>(ContBB->getTerminator());
4013   EXPECT_NE(ContTI, nullptr);
4014   BasicBlock *EndBB = ContTI->getSuccessor(0);
4015   EXPECT_TRUE(ContTI->isConditional());
4016   EXPECT_EQ(ContTI->getSuccessor(1), ContBB);
4017   EXPECT_NE(EndBB, nullptr);
4018 
4019   PHINode *Phi = dyn_cast<PHINode>(&ContBB->front());
4020   EXPECT_NE(Phi, nullptr);
4021   EXPECT_EQ(Phi->getNumIncomingValues(), 2U);
4022   EXPECT_EQ(Phi->getIncomingBlock(0), EntryBB);
4023   EXPECT_EQ(Phi->getIncomingBlock(1), ContBB);
4024 
4025   EXPECT_EQ(Sub->getNumUses(), 1U);
4026   StoreInst *St = dyn_cast<StoreInst>(Sub->user_back());
4027   AllocaInst *UpdateTemp = dyn_cast<AllocaInst>(St->getPointerOperand());
4028 
4029   ExtractValueInst *ExVI1 =
4030       dyn_cast<ExtractValueInst>(Phi->getIncomingValueForBlock(ContBB));
4031   EXPECT_NE(ExVI1, nullptr);
4032   AtomicCmpXchgInst *CmpExchg =
4033       dyn_cast<AtomicCmpXchgInst>(ExVI1->getAggregateOperand());
4034   EXPECT_NE(CmpExchg, nullptr);
4035   EXPECT_EQ(CmpExchg->getPointerOperand(), XVal);
4036   EXPECT_EQ(CmpExchg->getCompareOperand(), Phi);
4037   EXPECT_EQ(CmpExchg->getSuccessOrdering(), AtomicOrdering::Monotonic);
4038 
4039   LoadInst *Ld = dyn_cast<LoadInst>(CmpExchg->getNewValOperand());
4040   EXPECT_NE(Ld, nullptr);
4041   EXPECT_EQ(UpdateTemp, Ld->getPointerOperand());
4042   Builder.CreateRetVoid();
4043   OMPBuilder.finalize();
4044   EXPECT_FALSE(verifyModule(*M, &errs()));
4045 }
4046 
4047 TEST_F(OpenMPIRBuilderTest, OMPAtomicUpdateIntr) {
4048   OpenMPIRBuilder OMPBuilder(*M);
4049   OMPBuilder.initialize();
4050   F->setName("func");
4051   IRBuilder<> Builder(BB);
4052 
4053   OpenMPIRBuilder::LocationDescription Loc({Builder.saveIP(), DL});
4054 
4055   Type *IntTy = Type::getInt32Ty(M->getContext());
4056   AllocaInst *XVal = Builder.CreateAlloca(IntTy);
4057   XVal->setName("AtomicVar");
4058   Builder.CreateStore(ConstantInt::get(Type::getInt32Ty(Ctx), 0), XVal);
4059   OpenMPIRBuilder::AtomicOpValue X = {XVal, IntTy, false, false};
4060   AtomicOrdering AO = AtomicOrdering::Monotonic;
4061   Constant *ConstVal = ConstantInt::get(Type::getInt32Ty(Ctx), 1);
4062   Value *Expr = ConstantInt::get(Type::getInt32Ty(Ctx), 1);
4063   AtomicRMWInst::BinOp RMWOp = AtomicRMWInst::UMax;
4064   bool IsXLHSInRHSPart = false;
4065 
4066   BasicBlock *EntryBB = BB;
4067   OpenMPIRBuilder::InsertPointTy AllocaIP(EntryBB,
4068                                           EntryBB->getFirstInsertionPt());
4069   Value *Sub = nullptr;
4070 
4071   auto UpdateOp = [&](Value *Atomic, IRBuilder<> &IRB) {
4072     Sub = IRB.CreateSub(ConstVal, Atomic);
4073     return Sub;
4074   };
4075   OpenMPIRBuilder::InsertPointOrErrorTy AfterIP = OMPBuilder.createAtomicUpdate(
4076       Builder, AllocaIP, X, Expr, AO, RMWOp, UpdateOp, IsXLHSInRHSPart);
4077   assert(AfterIP && "unexpected error");
4078   Builder.restoreIP(*AfterIP);
4079   BasicBlock *ContBB = EntryBB->getSingleSuccessor();
4080   BranchInst *ContTI = dyn_cast<BranchInst>(ContBB->getTerminator());
4081   EXPECT_NE(ContTI, nullptr);
4082   BasicBlock *EndBB = ContTI->getSuccessor(0);
4083   EXPECT_TRUE(ContTI->isConditional());
4084   EXPECT_EQ(ContTI->getSuccessor(1), ContBB);
4085   EXPECT_NE(EndBB, nullptr);
4086 
4087   PHINode *Phi = dyn_cast<PHINode>(&ContBB->front());
4088   EXPECT_NE(Phi, nullptr);
4089   EXPECT_EQ(Phi->getNumIncomingValues(), 2U);
4090   EXPECT_EQ(Phi->getIncomingBlock(0), EntryBB);
4091   EXPECT_EQ(Phi->getIncomingBlock(1), ContBB);
4092 
4093   EXPECT_EQ(Sub->getNumUses(), 1U);
4094   StoreInst *St = dyn_cast<StoreInst>(Sub->user_back());
4095   AllocaInst *UpdateTemp = dyn_cast<AllocaInst>(St->getPointerOperand());
4096 
4097   ExtractValueInst *ExVI1 =
4098       dyn_cast<ExtractValueInst>(Phi->getIncomingValueForBlock(ContBB));
4099   EXPECT_NE(ExVI1, nullptr);
4100   AtomicCmpXchgInst *CmpExchg =
4101       dyn_cast<AtomicCmpXchgInst>(ExVI1->getAggregateOperand());
4102   EXPECT_NE(CmpExchg, nullptr);
4103   EXPECT_EQ(CmpExchg->getPointerOperand(), XVal);
4104   EXPECT_EQ(CmpExchg->getCompareOperand(), Phi);
4105   EXPECT_EQ(CmpExchg->getSuccessOrdering(), AtomicOrdering::Monotonic);
4106 
4107   LoadInst *Ld = dyn_cast<LoadInst>(CmpExchg->getNewValOperand());
4108   EXPECT_NE(Ld, nullptr);
4109   EXPECT_EQ(UpdateTemp, Ld->getPointerOperand());
4110 
4111   Builder.CreateRetVoid();
4112   OMPBuilder.finalize();
4113   EXPECT_FALSE(verifyModule(*M, &errs()));
4114 }
4115 
4116 TEST_F(OpenMPIRBuilderTest, OMPAtomicCapture) {
4117   OpenMPIRBuilder OMPBuilder(*M);
4118   OMPBuilder.initialize();
4119   F->setName("func");
4120   IRBuilder<> Builder(BB);
4121 
4122   OpenMPIRBuilder::LocationDescription Loc({Builder.saveIP(), DL});
4123 
4124   LLVMContext &Ctx = M->getContext();
4125   IntegerType *Int32 = Type::getInt32Ty(Ctx);
4126   AllocaInst *XVal = Builder.CreateAlloca(Int32);
4127   XVal->setName("AtomicVar");
4128   AllocaInst *VVal = Builder.CreateAlloca(Int32);
4129   VVal->setName("AtomicCapTar");
4130   StoreInst *Init =
4131       Builder.CreateStore(ConstantInt::get(Type::getInt32Ty(Ctx), 0U), XVal);
4132 
4133   OpenMPIRBuilder::AtomicOpValue X = {XVal, Int32, false, false};
4134   OpenMPIRBuilder::AtomicOpValue V = {VVal, Int32, false, false};
4135   AtomicOrdering AO = AtomicOrdering::Monotonic;
4136   ConstantInt *Expr = ConstantInt::get(Type::getInt32Ty(Ctx), 1U);
4137   AtomicRMWInst::BinOp RMWOp = AtomicRMWInst::Add;
4138   bool IsXLHSInRHSPart = true;
4139   bool IsPostfixUpdate = true;
4140   bool UpdateExpr = true;
4141 
4142   BasicBlock *EntryBB = BB;
4143   OpenMPIRBuilder::InsertPointTy AllocaIP(EntryBB,
4144                                           EntryBB->getFirstInsertionPt());
4145 
4146   // integer update - not used
4147   auto UpdateOp = [&](Value *Atomic, IRBuilder<> &IRB) { return nullptr; };
4148 
4149   OpenMPIRBuilder::InsertPointOrErrorTy AfterIP =
4150       OMPBuilder.createAtomicCapture(Builder, AllocaIP, X, V, Expr, AO, RMWOp,
4151                                      UpdateOp, UpdateExpr, IsPostfixUpdate,
4152                                      IsXLHSInRHSPart);
4153   assert(AfterIP && "unexpected error");
4154   Builder.restoreIP(*AfterIP);
4155   EXPECT_EQ(EntryBB->getParent()->size(), 1U);
4156   AtomicRMWInst *ARWM = dyn_cast<AtomicRMWInst>(Init->getNextNode());
4157   EXPECT_NE(ARWM, nullptr);
4158   EXPECT_EQ(ARWM->getPointerOperand(), XVal);
4159   EXPECT_EQ(ARWM->getOperation(), RMWOp);
4160   StoreInst *St = dyn_cast<StoreInst>(ARWM->user_back());
4161   EXPECT_NE(St, nullptr);
4162   EXPECT_EQ(St->getPointerOperand(), VVal);
4163 
4164   Builder.CreateRetVoid();
4165   OMPBuilder.finalize();
4166   EXPECT_FALSE(verifyModule(*M, &errs()));
4167 }
4168 
4169 TEST_F(OpenMPIRBuilderTest, OMPAtomicCompare) {
4170   OpenMPIRBuilder OMPBuilder(*M);
4171   OMPBuilder.initialize();
4172   F->setName("func");
4173   IRBuilder<> Builder(BB);
4174 
4175   OpenMPIRBuilder::LocationDescription Loc({Builder.saveIP(), DL});
4176 
4177   LLVMContext &Ctx = M->getContext();
4178   IntegerType *Int32 = Type::getInt32Ty(Ctx);
4179   AllocaInst *XVal = Builder.CreateAlloca(Int32);
4180   XVal->setName("x");
4181   StoreInst *Init =
4182       Builder.CreateStore(ConstantInt::get(Type::getInt32Ty(Ctx), 0U), XVal);
4183 
4184   OpenMPIRBuilder::AtomicOpValue XSigned = {XVal, Int32, true, false};
4185   OpenMPIRBuilder::AtomicOpValue XUnsigned = {XVal, Int32, false, false};
4186   // V and R are not used in atomic compare
4187   OpenMPIRBuilder::AtomicOpValue V = {nullptr, nullptr, false, false};
4188   OpenMPIRBuilder::AtomicOpValue R = {nullptr, nullptr, false, false};
4189   AtomicOrdering AO = AtomicOrdering::Monotonic;
4190   ConstantInt *Expr = ConstantInt::get(Type::getInt32Ty(Ctx), 1U);
4191   ConstantInt *D = ConstantInt::get(Type::getInt32Ty(Ctx), 1U);
4192   OMPAtomicCompareOp OpMax = OMPAtomicCompareOp::MAX;
4193   OMPAtomicCompareOp OpEQ = OMPAtomicCompareOp::EQ;
4194 
4195   Builder.restoreIP(OMPBuilder.createAtomicCompare(
4196       Builder, XSigned, V, R, Expr, nullptr, AO, OpMax, true, false, false));
4197   Builder.restoreIP(OMPBuilder.createAtomicCompare(
4198       Builder, XUnsigned, V, R, Expr, nullptr, AO, OpMax, false, false, false));
4199   Builder.restoreIP(OMPBuilder.createAtomicCompare(
4200       Builder, XSigned, V, R, Expr, D, AO, OpEQ, true, false, false));
4201 
4202   BasicBlock *EntryBB = BB;
4203   EXPECT_EQ(EntryBB->getParent()->size(), 1U);
4204   EXPECT_EQ(EntryBB->size(), 5U);
4205 
4206   AtomicRMWInst *ARWM1 = dyn_cast<AtomicRMWInst>(Init->getNextNode());
4207   EXPECT_NE(ARWM1, nullptr);
4208   EXPECT_EQ(ARWM1->getPointerOperand(), XVal);
4209   EXPECT_EQ(ARWM1->getValOperand(), Expr);
4210   EXPECT_EQ(ARWM1->getOperation(), AtomicRMWInst::Min);
4211 
4212   AtomicRMWInst *ARWM2 = dyn_cast<AtomicRMWInst>(ARWM1->getNextNode());
4213   EXPECT_NE(ARWM2, nullptr);
4214   EXPECT_EQ(ARWM2->getPointerOperand(), XVal);
4215   EXPECT_EQ(ARWM2->getValOperand(), Expr);
4216   EXPECT_EQ(ARWM2->getOperation(), AtomicRMWInst::UMax);
4217 
4218   AtomicCmpXchgInst *AXCHG = dyn_cast<AtomicCmpXchgInst>(ARWM2->getNextNode());
4219   EXPECT_NE(AXCHG, nullptr);
4220   EXPECT_EQ(AXCHG->getPointerOperand(), XVal);
4221   EXPECT_EQ(AXCHG->getCompareOperand(), Expr);
4222   EXPECT_EQ(AXCHG->getNewValOperand(), D);
4223 
4224   Builder.CreateRetVoid();
4225   OMPBuilder.finalize();
4226   EXPECT_FALSE(verifyModule(*M, &errs()));
4227 }
4228 
4229 TEST_F(OpenMPIRBuilderTest, OMPAtomicCompareCapture) {
4230   OpenMPIRBuilder OMPBuilder(*M);
4231   OMPBuilder.initialize();
4232   F->setName("func");
4233   IRBuilder<> Builder(BB);
4234 
4235   OpenMPIRBuilder::LocationDescription Loc({Builder.saveIP(), DL});
4236 
4237   LLVMContext &Ctx = M->getContext();
4238   IntegerType *Int32 = Type::getInt32Ty(Ctx);
4239   AllocaInst *XVal = Builder.CreateAlloca(Int32);
4240   XVal->setName("x");
4241   AllocaInst *VVal = Builder.CreateAlloca(Int32);
4242   VVal->setName("v");
4243   AllocaInst *RVal = Builder.CreateAlloca(Int32);
4244   RVal->setName("r");
4245 
4246   StoreInst *Init =
4247       Builder.CreateStore(ConstantInt::get(Type::getInt32Ty(Ctx), 0U), XVal);
4248 
4249   OpenMPIRBuilder::AtomicOpValue X = {XVal, Int32, true, false};
4250   OpenMPIRBuilder::AtomicOpValue V = {VVal, Int32, false, false};
4251   OpenMPIRBuilder::AtomicOpValue NoV = {nullptr, nullptr, false, false};
4252   OpenMPIRBuilder::AtomicOpValue R = {RVal, Int32, false, false};
4253   OpenMPIRBuilder::AtomicOpValue NoR = {nullptr, nullptr, false, false};
4254 
4255   AtomicOrdering AO = AtomicOrdering::Monotonic;
4256   ConstantInt *Expr = ConstantInt::get(Type::getInt32Ty(Ctx), 1U);
4257   ConstantInt *D = ConstantInt::get(Type::getInt32Ty(Ctx), 1U);
4258   OMPAtomicCompareOp OpMax = OMPAtomicCompareOp::MAX;
4259   OMPAtomicCompareOp OpEQ = OMPAtomicCompareOp::EQ;
4260 
4261   // { cond-update-stmt v = x; }
4262   Builder.restoreIP(OMPBuilder.createAtomicCompare(
4263       Builder, X, V, NoR, Expr, D, AO, OpEQ, /* IsXBinopExpr */ true,
4264       /* IsPostfixUpdate */ false,
4265       /* IsFailOnly */ false));
4266   // { v = x; cond-update-stmt }
4267   Builder.restoreIP(OMPBuilder.createAtomicCompare(
4268       Builder, X, V, NoR, Expr, D, AO, OpEQ, /* IsXBinopExpr */ true,
4269       /* IsPostfixUpdate */ true,
4270       /* IsFailOnly */ false));
4271   // if(x == e) { x = d; } else { v = x; }
4272   Builder.restoreIP(OMPBuilder.createAtomicCompare(
4273       Builder, X, V, NoR, Expr, D, AO, OpEQ, /* IsXBinopExpr */ true,
4274       /* IsPostfixUpdate */ false,
4275       /* IsFailOnly */ true));
4276   // { r = x == e; if(r) { x = d; } }
4277   Builder.restoreIP(OMPBuilder.createAtomicCompare(
4278       Builder, X, NoV, R, Expr, D, AO, OpEQ, /* IsXBinopExpr */ true,
4279       /* IsPostfixUpdate */ false,
4280       /* IsFailOnly */ false));
4281   // { r = x == e; if(r) { x = d; } else { v = x; } }
4282   Builder.restoreIP(OMPBuilder.createAtomicCompare(
4283       Builder, X, V, R, Expr, D, AO, OpEQ, /* IsXBinopExpr */ true,
4284       /* IsPostfixUpdate */ false,
4285       /* IsFailOnly */ true));
4286 
4287   // { v = x; cond-update-stmt }
4288   Builder.restoreIP(OMPBuilder.createAtomicCompare(
4289       Builder, X, V, NoR, Expr, nullptr, AO, OpMax, /* IsXBinopExpr */ true,
4290       /* IsPostfixUpdate */ true,
4291       /* IsFailOnly */ false));
4292   // { cond-update-stmt v = x; }
4293   Builder.restoreIP(OMPBuilder.createAtomicCompare(
4294       Builder, X, V, NoR, Expr, nullptr, AO, OpMax, /* IsXBinopExpr */ false,
4295       /* IsPostfixUpdate */ false,
4296       /* IsFailOnly */ false));
4297 
4298   BasicBlock *EntryBB = BB;
4299   EXPECT_EQ(EntryBB->getParent()->size(), 5U);
4300   BasicBlock *Cont1 = dyn_cast<BasicBlock>(EntryBB->getNextNode());
4301   EXPECT_NE(Cont1, nullptr);
4302   BasicBlock *Exit1 = dyn_cast<BasicBlock>(Cont1->getNextNode());
4303   EXPECT_NE(Exit1, nullptr);
4304   BasicBlock *Cont2 = dyn_cast<BasicBlock>(Exit1->getNextNode());
4305   EXPECT_NE(Cont2, nullptr);
4306   BasicBlock *Exit2 = dyn_cast<BasicBlock>(Cont2->getNextNode());
4307   EXPECT_NE(Exit2, nullptr);
4308 
4309   AtomicCmpXchgInst *CmpXchg1 =
4310       dyn_cast<AtomicCmpXchgInst>(Init->getNextNode());
4311   EXPECT_NE(CmpXchg1, nullptr);
4312   EXPECT_EQ(CmpXchg1->getPointerOperand(), XVal);
4313   EXPECT_EQ(CmpXchg1->getCompareOperand(), Expr);
4314   EXPECT_EQ(CmpXchg1->getNewValOperand(), D);
4315   ExtractValueInst *ExtVal1 =
4316       dyn_cast<ExtractValueInst>(CmpXchg1->getNextNode());
4317   EXPECT_NE(ExtVal1, nullptr);
4318   EXPECT_EQ(ExtVal1->getAggregateOperand(), CmpXchg1);
4319   EXPECT_EQ(ExtVal1->getIndices(), ArrayRef<unsigned int>(0U));
4320   ExtractValueInst *ExtVal2 =
4321       dyn_cast<ExtractValueInst>(ExtVal1->getNextNode());
4322   EXPECT_NE(ExtVal2, nullptr);
4323   EXPECT_EQ(ExtVal2->getAggregateOperand(), CmpXchg1);
4324   EXPECT_EQ(ExtVal2->getIndices(), ArrayRef<unsigned int>(1U));
4325   SelectInst *Sel1 = dyn_cast<SelectInst>(ExtVal2->getNextNode());
4326   EXPECT_NE(Sel1, nullptr);
4327   EXPECT_EQ(Sel1->getCondition(), ExtVal2);
4328   EXPECT_EQ(Sel1->getTrueValue(), Expr);
4329   EXPECT_EQ(Sel1->getFalseValue(), ExtVal1);
4330   StoreInst *Store1 = dyn_cast<StoreInst>(Sel1->getNextNode());
4331   EXPECT_NE(Store1, nullptr);
4332   EXPECT_EQ(Store1->getPointerOperand(), VVal);
4333   EXPECT_EQ(Store1->getValueOperand(), Sel1);
4334 
4335   AtomicCmpXchgInst *CmpXchg2 =
4336       dyn_cast<AtomicCmpXchgInst>(Store1->getNextNode());
4337   EXPECT_NE(CmpXchg2, nullptr);
4338   EXPECT_EQ(CmpXchg2->getPointerOperand(), XVal);
4339   EXPECT_EQ(CmpXchg2->getCompareOperand(), Expr);
4340   EXPECT_EQ(CmpXchg2->getNewValOperand(), D);
4341   ExtractValueInst *ExtVal3 =
4342       dyn_cast<ExtractValueInst>(CmpXchg2->getNextNode());
4343   EXPECT_NE(ExtVal3, nullptr);
4344   EXPECT_EQ(ExtVal3->getAggregateOperand(), CmpXchg2);
4345   EXPECT_EQ(ExtVal3->getIndices(), ArrayRef<unsigned int>(0U));
4346   StoreInst *Store2 = dyn_cast<StoreInst>(ExtVal3->getNextNode());
4347   EXPECT_NE(Store2, nullptr);
4348   EXPECT_EQ(Store2->getPointerOperand(), VVal);
4349   EXPECT_EQ(Store2->getValueOperand(), ExtVal3);
4350 
4351   AtomicCmpXchgInst *CmpXchg3 =
4352       dyn_cast<AtomicCmpXchgInst>(Store2->getNextNode());
4353   EXPECT_NE(CmpXchg3, nullptr);
4354   EXPECT_EQ(CmpXchg3->getPointerOperand(), XVal);
4355   EXPECT_EQ(CmpXchg3->getCompareOperand(), Expr);
4356   EXPECT_EQ(CmpXchg3->getNewValOperand(), D);
4357   ExtractValueInst *ExtVal4 =
4358       dyn_cast<ExtractValueInst>(CmpXchg3->getNextNode());
4359   EXPECT_NE(ExtVal4, nullptr);
4360   EXPECT_EQ(ExtVal4->getAggregateOperand(), CmpXchg3);
4361   EXPECT_EQ(ExtVal4->getIndices(), ArrayRef<unsigned int>(0U));
4362   ExtractValueInst *ExtVal5 =
4363       dyn_cast<ExtractValueInst>(ExtVal4->getNextNode());
4364   EXPECT_NE(ExtVal5, nullptr);
4365   EXPECT_EQ(ExtVal5->getAggregateOperand(), CmpXchg3);
4366   EXPECT_EQ(ExtVal5->getIndices(), ArrayRef<unsigned int>(1U));
4367   BranchInst *Br1 = dyn_cast<BranchInst>(ExtVal5->getNextNode());
4368   EXPECT_NE(Br1, nullptr);
4369   EXPECT_EQ(Br1->isConditional(), true);
4370   EXPECT_EQ(Br1->getCondition(), ExtVal5);
4371   EXPECT_EQ(Br1->getSuccessor(0), Exit1);
4372   EXPECT_EQ(Br1->getSuccessor(1), Cont1);
4373 
4374   StoreInst *Store3 = dyn_cast<StoreInst>(&Cont1->front());
4375   EXPECT_NE(Store3, nullptr);
4376   EXPECT_EQ(Store3->getPointerOperand(), VVal);
4377   EXPECT_EQ(Store3->getValueOperand(), ExtVal4);
4378   BranchInst *Br2 = dyn_cast<BranchInst>(Store3->getNextNode());
4379   EXPECT_NE(Br2, nullptr);
4380   EXPECT_EQ(Br2->isUnconditional(), true);
4381   EXPECT_EQ(Br2->getSuccessor(0), Exit1);
4382 
4383   AtomicCmpXchgInst *CmpXchg4 = dyn_cast<AtomicCmpXchgInst>(&Exit1->front());
4384   EXPECT_NE(CmpXchg4, nullptr);
4385   EXPECT_EQ(CmpXchg4->getPointerOperand(), XVal);
4386   EXPECT_EQ(CmpXchg4->getCompareOperand(), Expr);
4387   EXPECT_EQ(CmpXchg4->getNewValOperand(), D);
4388   ExtractValueInst *ExtVal6 =
4389       dyn_cast<ExtractValueInst>(CmpXchg4->getNextNode());
4390   EXPECT_NE(ExtVal6, nullptr);
4391   EXPECT_EQ(ExtVal6->getAggregateOperand(), CmpXchg4);
4392   EXPECT_EQ(ExtVal6->getIndices(), ArrayRef<unsigned int>(1U));
4393   ZExtInst *ZExt1 = dyn_cast<ZExtInst>(ExtVal6->getNextNode());
4394   EXPECT_NE(ZExt1, nullptr);
4395   EXPECT_EQ(ZExt1->getDestTy(), Int32);
4396   StoreInst *Store4 = dyn_cast<StoreInst>(ZExt1->getNextNode());
4397   EXPECT_NE(Store4, nullptr);
4398   EXPECT_EQ(Store4->getPointerOperand(), RVal);
4399   EXPECT_EQ(Store4->getValueOperand(), ZExt1);
4400 
4401   AtomicCmpXchgInst *CmpXchg5 =
4402       dyn_cast<AtomicCmpXchgInst>(Store4->getNextNode());
4403   EXPECT_NE(CmpXchg5, nullptr);
4404   EXPECT_EQ(CmpXchg5->getPointerOperand(), XVal);
4405   EXPECT_EQ(CmpXchg5->getCompareOperand(), Expr);
4406   EXPECT_EQ(CmpXchg5->getNewValOperand(), D);
4407   ExtractValueInst *ExtVal7 =
4408       dyn_cast<ExtractValueInst>(CmpXchg5->getNextNode());
4409   EXPECT_NE(ExtVal7, nullptr);
4410   EXPECT_EQ(ExtVal7->getAggregateOperand(), CmpXchg5);
4411   EXPECT_EQ(ExtVal7->getIndices(), ArrayRef<unsigned int>(0U));
4412   ExtractValueInst *ExtVal8 =
4413       dyn_cast<ExtractValueInst>(ExtVal7->getNextNode());
4414   EXPECT_NE(ExtVal8, nullptr);
4415   EXPECT_EQ(ExtVal8->getAggregateOperand(), CmpXchg5);
4416   EXPECT_EQ(ExtVal8->getIndices(), ArrayRef<unsigned int>(1U));
4417   BranchInst *Br3 = dyn_cast<BranchInst>(ExtVal8->getNextNode());
4418   EXPECT_NE(Br3, nullptr);
4419   EXPECT_EQ(Br3->isConditional(), true);
4420   EXPECT_EQ(Br3->getCondition(), ExtVal8);
4421   EXPECT_EQ(Br3->getSuccessor(0), Exit2);
4422   EXPECT_EQ(Br3->getSuccessor(1), Cont2);
4423 
4424   StoreInst *Store5 = dyn_cast<StoreInst>(&Cont2->front());
4425   EXPECT_NE(Store5, nullptr);
4426   EXPECT_EQ(Store5->getPointerOperand(), VVal);
4427   EXPECT_EQ(Store5->getValueOperand(), ExtVal7);
4428   BranchInst *Br4 = dyn_cast<BranchInst>(Store5->getNextNode());
4429   EXPECT_NE(Br4, nullptr);
4430   EXPECT_EQ(Br4->isUnconditional(), true);
4431   EXPECT_EQ(Br4->getSuccessor(0), Exit2);
4432 
4433   ExtractValueInst *ExtVal9 = dyn_cast<ExtractValueInst>(&Exit2->front());
4434   EXPECT_NE(ExtVal9, nullptr);
4435   EXPECT_EQ(ExtVal9->getAggregateOperand(), CmpXchg5);
4436   EXPECT_EQ(ExtVal9->getIndices(), ArrayRef<unsigned int>(1U));
4437   ZExtInst *ZExt2 = dyn_cast<ZExtInst>(ExtVal9->getNextNode());
4438   EXPECT_NE(ZExt2, nullptr);
4439   EXPECT_EQ(ZExt2->getDestTy(), Int32);
4440   StoreInst *Store6 = dyn_cast<StoreInst>(ZExt2->getNextNode());
4441   EXPECT_NE(Store6, nullptr);
4442   EXPECT_EQ(Store6->getPointerOperand(), RVal);
4443   EXPECT_EQ(Store6->getValueOperand(), ZExt2);
4444 
4445   AtomicRMWInst *ARWM1 = dyn_cast<AtomicRMWInst>(Store6->getNextNode());
4446   EXPECT_NE(ARWM1, nullptr);
4447   EXPECT_EQ(ARWM1->getPointerOperand(), XVal);
4448   EXPECT_EQ(ARWM1->getValOperand(), Expr);
4449   EXPECT_EQ(ARWM1->getOperation(), AtomicRMWInst::Min);
4450   StoreInst *Store7 = dyn_cast<StoreInst>(ARWM1->getNextNode());
4451   EXPECT_NE(Store7, nullptr);
4452   EXPECT_EQ(Store7->getPointerOperand(), VVal);
4453   EXPECT_EQ(Store7->getValueOperand(), ARWM1);
4454 
4455   AtomicRMWInst *ARWM2 = dyn_cast<AtomicRMWInst>(Store7->getNextNode());
4456   EXPECT_NE(ARWM2, nullptr);
4457   EXPECT_EQ(ARWM2->getPointerOperand(), XVal);
4458   EXPECT_EQ(ARWM2->getValOperand(), Expr);
4459   EXPECT_EQ(ARWM2->getOperation(), AtomicRMWInst::Max);
4460   CmpInst *Cmp1 = dyn_cast<CmpInst>(ARWM2->getNextNode());
4461   EXPECT_NE(Cmp1, nullptr);
4462   EXPECT_EQ(Cmp1->getPredicate(), CmpInst::ICMP_SGT);
4463   EXPECT_EQ(Cmp1->getOperand(0), ARWM2);
4464   EXPECT_EQ(Cmp1->getOperand(1), Expr);
4465   SelectInst *Sel2 = dyn_cast<SelectInst>(Cmp1->getNextNode());
4466   EXPECT_NE(Sel2, nullptr);
4467   EXPECT_EQ(Sel2->getCondition(), Cmp1);
4468   EXPECT_EQ(Sel2->getTrueValue(), Expr);
4469   EXPECT_EQ(Sel2->getFalseValue(), ARWM2);
4470   StoreInst *Store8 = dyn_cast<StoreInst>(Sel2->getNextNode());
4471   EXPECT_NE(Store8, nullptr);
4472   EXPECT_EQ(Store8->getPointerOperand(), VVal);
4473   EXPECT_EQ(Store8->getValueOperand(), Sel2);
4474 
4475   Builder.CreateRetVoid();
4476   OMPBuilder.finalize();
4477   EXPECT_FALSE(verifyModule(*M, &errs()));
4478 }
4479 
4480 TEST_F(OpenMPIRBuilderTest, CreateTeams) {
4481   using InsertPointTy = OpenMPIRBuilder::InsertPointTy;
4482   OpenMPIRBuilder OMPBuilder(*M);
4483   OMPBuilder.Config.IsTargetDevice = false;
4484   OMPBuilder.initialize();
4485   F->setName("func");
4486   IRBuilder<> Builder(BB);
4487 
4488   AllocaInst *ValPtr32 = Builder.CreateAlloca(Builder.getInt32Ty());
4489   AllocaInst *ValPtr128 = Builder.CreateAlloca(Builder.getInt128Ty());
4490   Value *Val128 = Builder.CreateLoad(Builder.getInt128Ty(), ValPtr128, "load");
4491 
4492   auto BodyGenCB = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP) {
4493     Builder.restoreIP(AllocaIP);
4494     AllocaInst *Local128 = Builder.CreateAlloca(Builder.getInt128Ty(), nullptr,
4495                                                 "bodygen.alloca128");
4496 
4497     Builder.restoreIP(CodeGenIP);
4498     // Loading and storing captured pointer and values
4499     Builder.CreateStore(Val128, Local128);
4500     Value *Val32 = Builder.CreateLoad(ValPtr32->getAllocatedType(), ValPtr32,
4501                                       "bodygen.load32");
4502 
4503     LoadInst *PrivLoad128 = Builder.CreateLoad(
4504         Local128->getAllocatedType(), Local128, "bodygen.local.load128");
4505     Value *Cmp = Builder.CreateICmpNE(
4506         Val32, Builder.CreateTrunc(PrivLoad128, Val32->getType()));
4507     Instruction *ThenTerm, *ElseTerm;
4508     SplitBlockAndInsertIfThenElse(Cmp, CodeGenIP.getBlock()->getTerminator(),
4509                                   &ThenTerm, &ElseTerm);
4510     return Error::success();
4511   };
4512 
4513   OpenMPIRBuilder::LocationDescription Loc({Builder.saveIP(), DL});
4514   OpenMPIRBuilder::InsertPointOrErrorTy AfterIP = OMPBuilder.createTeams(
4515       Builder, BodyGenCB, /*NumTeamsLower=*/nullptr, /*NumTeamsUpper=*/nullptr,
4516       /*ThreadLimit=*/nullptr, /*IfExpr=*/nullptr);
4517   assert(AfterIP && "unexpected error");
4518   Builder.restoreIP(*AfterIP);
4519 
4520   OMPBuilder.finalize();
4521   Builder.CreateRetVoid();
4522 
4523   EXPECT_FALSE(verifyModule(*M, &errs()));
4524 
4525   CallInst *TeamsForkCall = dyn_cast<CallInst>(
4526       OMPBuilder.getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_fork_teams)
4527           ->user_back());
4528 
4529   // Verify the Ident argument
4530   GlobalVariable *Ident = cast<GlobalVariable>(TeamsForkCall->getArgOperand(0));
4531   ASSERT_NE(Ident, nullptr);
4532   EXPECT_TRUE(Ident->hasInitializer());
4533   Constant *Initializer = Ident->getInitializer();
4534   GlobalVariable *SrcStrGlob =
4535       cast<GlobalVariable>(Initializer->getOperand(4)->stripPointerCasts());
4536   ASSERT_NE(SrcStrGlob, nullptr);
4537   ConstantDataArray *SrcSrc =
4538       dyn_cast<ConstantDataArray>(SrcStrGlob->getInitializer());
4539   ASSERT_NE(SrcSrc, nullptr);
4540 
4541   // Verify the outlined function signature.
4542   Function *OutlinedFn =
4543       dyn_cast<Function>(TeamsForkCall->getArgOperand(2)->stripPointerCasts());
4544   ASSERT_NE(OutlinedFn, nullptr);
4545   EXPECT_FALSE(OutlinedFn->isDeclaration());
4546   EXPECT_TRUE(OutlinedFn->arg_size() >= 3);
4547   EXPECT_EQ(OutlinedFn->getArg(0)->getType(), Builder.getPtrTy()); // global_tid
4548   EXPECT_EQ(OutlinedFn->getArg(1)->getType(), Builder.getPtrTy()); // bound_tid
4549   EXPECT_EQ(OutlinedFn->getArg(2)->getType(),
4550             Builder.getPtrTy()); // captured args
4551 
4552   // Check for TruncInst and ICmpInst in the outlined function.
4553   EXPECT_TRUE(any_of(instructions(OutlinedFn),
4554                      [](Instruction &inst) { return isa<TruncInst>(&inst); }));
4555   EXPECT_TRUE(any_of(instructions(OutlinedFn),
4556                      [](Instruction &inst) { return isa<ICmpInst>(&inst); }));
4557 }
4558 
4559 TEST_F(OpenMPIRBuilderTest, CreateTeamsWithThreadLimit) {
4560   using InsertPointTy = OpenMPIRBuilder::InsertPointTy;
4561   OpenMPIRBuilder OMPBuilder(*M);
4562   OMPBuilder.Config.IsTargetDevice = false;
4563   OMPBuilder.initialize();
4564   F->setName("func");
4565   IRBuilder<> &Builder = OMPBuilder.Builder;
4566   Builder.SetInsertPoint(BB);
4567 
4568   Function *FakeFunction =
4569       Function::Create(FunctionType::get(Builder.getVoidTy(), false),
4570                        GlobalValue::ExternalLinkage, "fakeFunction", M.get());
4571 
4572   auto BodyGenCB = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP) {
4573     Builder.restoreIP(CodeGenIP);
4574     Builder.CreateCall(FakeFunction, {});
4575     return Error::success();
4576   };
4577 
4578   // `F` has an argument - an integer, so we use that as the thread limit.
4579   OpenMPIRBuilder::InsertPointOrErrorTy AfterIP = OMPBuilder.createTeams(
4580       /*=*/Builder, BodyGenCB, /*NumTeamsLower=*/nullptr,
4581       /*NumTeamsUpper=*/nullptr, /*ThreadLimit=*/F->arg_begin(),
4582       /*IfExpr=*/nullptr);
4583   assert(AfterIP && "unexpected error");
4584   Builder.restoreIP(*AfterIP);
4585 
4586   Builder.CreateRetVoid();
4587   OMPBuilder.finalize();
4588 
4589   ASSERT_FALSE(verifyModule(*M));
4590 
4591   CallInst *PushNumTeamsCallInst =
4592       findSingleCall(F, OMPRTL___kmpc_push_num_teams_51, OMPBuilder);
4593   ASSERT_NE(PushNumTeamsCallInst, nullptr);
4594 
4595   EXPECT_EQ(PushNumTeamsCallInst->getArgOperand(2), Builder.getInt32(0));
4596   EXPECT_EQ(PushNumTeamsCallInst->getArgOperand(3), Builder.getInt32(0));
4597   EXPECT_EQ(PushNumTeamsCallInst->getArgOperand(4), &*F->arg_begin());
4598 
4599   // Verifying that the next instruction to execute is kmpc_fork_teams
4600   BranchInst *BrInst =
4601       dyn_cast<BranchInst>(PushNumTeamsCallInst->getNextNonDebugInstruction());
4602   ASSERT_NE(BrInst, nullptr);
4603   ASSERT_EQ(BrInst->getNumSuccessors(), 1U);
4604   Instruction *NextInstruction =
4605       BrInst->getSuccessor(0)->getFirstNonPHIOrDbgOrLifetime();
4606   CallInst *ForkTeamsCI = dyn_cast_if_present<CallInst>(NextInstruction);
4607   ASSERT_NE(ForkTeamsCI, nullptr);
4608   EXPECT_EQ(ForkTeamsCI->getCalledFunction(),
4609             OMPBuilder.getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_fork_teams));
4610 }
4611 
4612 TEST_F(OpenMPIRBuilderTest, CreateTeamsWithNumTeamsUpper) {
4613   using InsertPointTy = OpenMPIRBuilder::InsertPointTy;
4614   OpenMPIRBuilder OMPBuilder(*M);
4615   OMPBuilder.Config.IsTargetDevice = false;
4616   OMPBuilder.initialize();
4617   F->setName("func");
4618   IRBuilder<> &Builder = OMPBuilder.Builder;
4619   Builder.SetInsertPoint(BB);
4620 
4621   Function *FakeFunction =
4622       Function::Create(FunctionType::get(Builder.getVoidTy(), false),
4623                        GlobalValue::ExternalLinkage, "fakeFunction", M.get());
4624 
4625   auto BodyGenCB = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP) {
4626     Builder.restoreIP(CodeGenIP);
4627     Builder.CreateCall(FakeFunction, {});
4628     return Error::success();
4629   };
4630 
4631   // `F` already has an integer argument, so we use that as upper bound to
4632   // `num_teams`
4633   OpenMPIRBuilder::InsertPointOrErrorTy AfterIP =
4634       OMPBuilder.createTeams(Builder, BodyGenCB,
4635                              /*NumTeamsLower=*/nullptr,
4636                              /*NumTeamsUpper=*/F->arg_begin(),
4637                              /*ThreadLimit=*/nullptr,
4638                              /*IfExpr=*/nullptr);
4639   assert(AfterIP && "unexpected error");
4640   Builder.restoreIP(*AfterIP);
4641 
4642   Builder.CreateRetVoid();
4643   OMPBuilder.finalize();
4644 
4645   ASSERT_FALSE(verifyModule(*M));
4646 
4647   CallInst *PushNumTeamsCallInst =
4648       findSingleCall(F, OMPRTL___kmpc_push_num_teams_51, OMPBuilder);
4649   ASSERT_NE(PushNumTeamsCallInst, nullptr);
4650 
4651   EXPECT_EQ(PushNumTeamsCallInst->getArgOperand(2), &*F->arg_begin());
4652   EXPECT_EQ(PushNumTeamsCallInst->getArgOperand(3), &*F->arg_begin());
4653   EXPECT_EQ(PushNumTeamsCallInst->getArgOperand(4), Builder.getInt32(0));
4654 
4655   // Verifying that the next instruction to execute is kmpc_fork_teams
4656   BranchInst *BrInst =
4657       dyn_cast<BranchInst>(PushNumTeamsCallInst->getNextNonDebugInstruction());
4658   ASSERT_NE(BrInst, nullptr);
4659   ASSERT_EQ(BrInst->getNumSuccessors(), 1U);
4660   Instruction *NextInstruction =
4661       BrInst->getSuccessor(0)->getFirstNonPHIOrDbgOrLifetime();
4662   CallInst *ForkTeamsCI = dyn_cast_if_present<CallInst>(NextInstruction);
4663   ASSERT_NE(ForkTeamsCI, nullptr);
4664   EXPECT_EQ(ForkTeamsCI->getCalledFunction(),
4665             OMPBuilder.getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_fork_teams));
4666 }
4667 
4668 TEST_F(OpenMPIRBuilderTest, CreateTeamsWithNumTeamsBoth) {
4669   using InsertPointTy = OpenMPIRBuilder::InsertPointTy;
4670   OpenMPIRBuilder OMPBuilder(*M);
4671   OMPBuilder.Config.IsTargetDevice = false;
4672   OMPBuilder.initialize();
4673   F->setName("func");
4674   IRBuilder<> &Builder = OMPBuilder.Builder;
4675   Builder.SetInsertPoint(BB);
4676 
4677   Function *FakeFunction =
4678       Function::Create(FunctionType::get(Builder.getVoidTy(), false),
4679                        GlobalValue::ExternalLinkage, "fakeFunction", M.get());
4680 
4681   Value *NumTeamsLower =
4682       Builder.CreateAdd(F->arg_begin(), Builder.getInt32(5), "numTeamsLower");
4683   Value *NumTeamsUpper =
4684       Builder.CreateAdd(F->arg_begin(), Builder.getInt32(10), "numTeamsUpper");
4685 
4686   auto BodyGenCB = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP) {
4687     Builder.restoreIP(CodeGenIP);
4688     Builder.CreateCall(FakeFunction, {});
4689     return Error::success();
4690   };
4691 
4692   // `F` already has an integer argument, so we use that as upper bound to
4693   // `num_teams`
4694   OpenMPIRBuilder::InsertPointOrErrorTy AfterIP =
4695       OMPBuilder.createTeams(Builder, BodyGenCB, NumTeamsLower, NumTeamsUpper,
4696                              /*ThreadLimit=*/nullptr, /*IfExpr=*/nullptr);
4697   assert(AfterIP && "unexpected error");
4698   Builder.restoreIP(*AfterIP);
4699 
4700   Builder.CreateRetVoid();
4701   OMPBuilder.finalize();
4702 
4703   ASSERT_FALSE(verifyModule(*M));
4704 
4705   CallInst *PushNumTeamsCallInst =
4706       findSingleCall(F, OMPRTL___kmpc_push_num_teams_51, OMPBuilder);
4707   ASSERT_NE(PushNumTeamsCallInst, nullptr);
4708 
4709   EXPECT_EQ(PushNumTeamsCallInst->getArgOperand(2), NumTeamsLower);
4710   EXPECT_EQ(PushNumTeamsCallInst->getArgOperand(3), NumTeamsUpper);
4711   EXPECT_EQ(PushNumTeamsCallInst->getArgOperand(4), Builder.getInt32(0));
4712 
4713   // Verifying that the next instruction to execute is kmpc_fork_teams
4714   BranchInst *BrInst =
4715       dyn_cast<BranchInst>(PushNumTeamsCallInst->getNextNonDebugInstruction());
4716   ASSERT_NE(BrInst, nullptr);
4717   ASSERT_EQ(BrInst->getNumSuccessors(), 1U);
4718   Instruction *NextInstruction =
4719       BrInst->getSuccessor(0)->getFirstNonPHIOrDbgOrLifetime();
4720   CallInst *ForkTeamsCI = dyn_cast_if_present<CallInst>(NextInstruction);
4721   ASSERT_NE(ForkTeamsCI, nullptr);
4722   EXPECT_EQ(ForkTeamsCI->getCalledFunction(),
4723             OMPBuilder.getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_fork_teams));
4724 }
4725 
4726 TEST_F(OpenMPIRBuilderTest, CreateTeamsWithNumTeamsAndThreadLimit) {
4727   using InsertPointTy = OpenMPIRBuilder::InsertPointTy;
4728   OpenMPIRBuilder OMPBuilder(*M);
4729   OMPBuilder.Config.IsTargetDevice = false;
4730   OMPBuilder.initialize();
4731   F->setName("func");
4732   IRBuilder<> &Builder = OMPBuilder.Builder;
4733   Builder.SetInsertPoint(BB);
4734 
4735   BasicBlock *CodegenBB = splitBB(Builder, true);
4736   Builder.SetInsertPoint(CodegenBB);
4737 
4738   // Generate values for `num_teams` and `thread_limit` using the first argument
4739   // of the testing function.
4740   Value *NumTeamsLower =
4741       Builder.CreateAdd(F->arg_begin(), Builder.getInt32(5), "numTeamsLower");
4742   Value *NumTeamsUpper =
4743       Builder.CreateAdd(F->arg_begin(), Builder.getInt32(10), "numTeamsUpper");
4744   Value *ThreadLimit =
4745       Builder.CreateAdd(F->arg_begin(), Builder.getInt32(20), "threadLimit");
4746 
4747   Function *FakeFunction =
4748       Function::Create(FunctionType::get(Builder.getVoidTy(), false),
4749                        GlobalValue::ExternalLinkage, "fakeFunction", M.get());
4750 
4751   auto BodyGenCB = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP) {
4752     Builder.restoreIP(CodeGenIP);
4753     Builder.CreateCall(FakeFunction, {});
4754     return Error::success();
4755   };
4756 
4757   OpenMPIRBuilder::LocationDescription Loc({Builder.saveIP(), DL});
4758   OpenMPIRBuilder::InsertPointOrErrorTy AfterIP = OMPBuilder.createTeams(
4759       Builder, BodyGenCB, NumTeamsLower, NumTeamsUpper, ThreadLimit, nullptr);
4760   assert(AfterIP && "unexpected error");
4761   Builder.restoreIP(*AfterIP);
4762 
4763   Builder.CreateRetVoid();
4764   OMPBuilder.finalize();
4765 
4766   ASSERT_FALSE(verifyModule(*M));
4767 
4768   CallInst *PushNumTeamsCallInst =
4769       findSingleCall(F, OMPRTL___kmpc_push_num_teams_51, OMPBuilder);
4770   ASSERT_NE(PushNumTeamsCallInst, nullptr);
4771 
4772   EXPECT_EQ(PushNumTeamsCallInst->getArgOperand(2), NumTeamsLower);
4773   EXPECT_EQ(PushNumTeamsCallInst->getArgOperand(3), NumTeamsUpper);
4774   EXPECT_EQ(PushNumTeamsCallInst->getArgOperand(4), ThreadLimit);
4775 
4776   // Verifying that the next instruction to execute is kmpc_fork_teams
4777   BranchInst *BrInst =
4778       dyn_cast<BranchInst>(PushNumTeamsCallInst->getNextNonDebugInstruction());
4779   ASSERT_NE(BrInst, nullptr);
4780   ASSERT_EQ(BrInst->getNumSuccessors(), 1U);
4781   Instruction *NextInstruction =
4782       BrInst->getSuccessor(0)->getFirstNonPHIOrDbgOrLifetime();
4783   CallInst *ForkTeamsCI = dyn_cast_if_present<CallInst>(NextInstruction);
4784   ASSERT_NE(ForkTeamsCI, nullptr);
4785   EXPECT_EQ(ForkTeamsCI->getCalledFunction(),
4786             OMPBuilder.getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_fork_teams));
4787 }
4788 
4789 TEST_F(OpenMPIRBuilderTest, CreateTeamsWithIfCondition) {
4790   using InsertPointTy = OpenMPIRBuilder::InsertPointTy;
4791   OpenMPIRBuilder OMPBuilder(*M);
4792   OMPBuilder.Config.IsTargetDevice = false;
4793   OMPBuilder.initialize();
4794   F->setName("func");
4795   IRBuilder<> &Builder = OMPBuilder.Builder;
4796   Builder.SetInsertPoint(BB);
4797 
4798   Value *IfExpr = Builder.CreateLoad(Builder.getInt1Ty(),
4799                                      Builder.CreateAlloca(Builder.getInt1Ty()));
4800 
4801   Function *FakeFunction =
4802       Function::Create(FunctionType::get(Builder.getVoidTy(), false),
4803                        GlobalValue::ExternalLinkage, "fakeFunction", M.get());
4804 
4805   auto BodyGenCB = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP) {
4806     Builder.restoreIP(CodeGenIP);
4807     Builder.CreateCall(FakeFunction, {});
4808     return Error::success();
4809   };
4810 
4811   // `F` already has an integer argument, so we use that as upper bound to
4812   // `num_teams`
4813   OpenMPIRBuilder::InsertPointOrErrorTy AfterIP = OMPBuilder.createTeams(
4814       Builder, BodyGenCB, /*NumTeamsLower=*/nullptr, /*NumTeamsUpper=*/nullptr,
4815       /*ThreadLimit=*/nullptr, IfExpr);
4816   assert(AfterIP && "unexpected error");
4817   Builder.restoreIP(*AfterIP);
4818 
4819   Builder.CreateRetVoid();
4820   OMPBuilder.finalize();
4821 
4822   ASSERT_FALSE(verifyModule(*M));
4823 
4824   CallInst *PushNumTeamsCallInst =
4825       findSingleCall(F, OMPRTL___kmpc_push_num_teams_51, OMPBuilder);
4826   ASSERT_NE(PushNumTeamsCallInst, nullptr);
4827   Value *NumTeamsLower = PushNumTeamsCallInst->getArgOperand(2);
4828   Value *NumTeamsUpper = PushNumTeamsCallInst->getArgOperand(3);
4829   Value *ThreadLimit = PushNumTeamsCallInst->getArgOperand(4);
4830 
4831   // Check the lower_bound
4832   ASSERT_NE(NumTeamsLower, nullptr);
4833   SelectInst *NumTeamsLowerSelectInst = dyn_cast<SelectInst>(NumTeamsLower);
4834   ASSERT_NE(NumTeamsLowerSelectInst, nullptr);
4835   EXPECT_EQ(NumTeamsLowerSelectInst->getCondition(), IfExpr);
4836   EXPECT_EQ(NumTeamsLowerSelectInst->getTrueValue(), Builder.getInt32(0));
4837   EXPECT_EQ(NumTeamsLowerSelectInst->getFalseValue(), Builder.getInt32(1));
4838 
4839   // Check the upper_bound
4840   ASSERT_NE(NumTeamsUpper, nullptr);
4841   SelectInst *NumTeamsUpperSelectInst = dyn_cast<SelectInst>(NumTeamsUpper);
4842   ASSERT_NE(NumTeamsUpperSelectInst, nullptr);
4843   EXPECT_EQ(NumTeamsUpperSelectInst->getCondition(), IfExpr);
4844   EXPECT_EQ(NumTeamsUpperSelectInst->getTrueValue(), Builder.getInt32(0));
4845   EXPECT_EQ(NumTeamsUpperSelectInst->getFalseValue(), Builder.getInt32(1));
4846 
4847   // Check thread_limit
4848   EXPECT_EQ(ThreadLimit, Builder.getInt32(0));
4849 }
4850 
4851 TEST_F(OpenMPIRBuilderTest, CreateTeamsWithIfConditionAndNumTeams) {
4852   using InsertPointTy = OpenMPIRBuilder::InsertPointTy;
4853   OpenMPIRBuilder OMPBuilder(*M);
4854   OMPBuilder.Config.IsTargetDevice = false;
4855   OMPBuilder.initialize();
4856   F->setName("func");
4857   IRBuilder<> &Builder = OMPBuilder.Builder;
4858   Builder.SetInsertPoint(BB);
4859 
4860   Value *IfExpr = Builder.CreateLoad(
4861       Builder.getInt32Ty(), Builder.CreateAlloca(Builder.getInt32Ty()));
4862   Value *NumTeamsLower = Builder.CreateAdd(F->arg_begin(), Builder.getInt32(5));
4863   Value *NumTeamsUpper =
4864       Builder.CreateAdd(F->arg_begin(), Builder.getInt32(10));
4865   Value *ThreadLimit = Builder.CreateAdd(F->arg_begin(), Builder.getInt32(20));
4866 
4867   Function *FakeFunction =
4868       Function::Create(FunctionType::get(Builder.getVoidTy(), false),
4869                        GlobalValue::ExternalLinkage, "fakeFunction", M.get());
4870 
4871   auto BodyGenCB = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP) {
4872     Builder.restoreIP(CodeGenIP);
4873     Builder.CreateCall(FakeFunction, {});
4874     return Error::success();
4875   };
4876 
4877   // `F` already has an integer argument, so we use that as upper bound to
4878   // `num_teams`
4879   OpenMPIRBuilder::InsertPointOrErrorTy AfterIP = OMPBuilder.createTeams(
4880       Builder, BodyGenCB, NumTeamsLower, NumTeamsUpper, ThreadLimit, IfExpr);
4881   assert(AfterIP && "unexpected error");
4882   Builder.restoreIP(*AfterIP);
4883 
4884   Builder.CreateRetVoid();
4885   OMPBuilder.finalize();
4886 
4887   ASSERT_FALSE(verifyModule(*M));
4888 
4889   CallInst *PushNumTeamsCallInst =
4890       findSingleCall(F, OMPRTL___kmpc_push_num_teams_51, OMPBuilder);
4891   ASSERT_NE(PushNumTeamsCallInst, nullptr);
4892   Value *NumTeamsLowerArg = PushNumTeamsCallInst->getArgOperand(2);
4893   Value *NumTeamsUpperArg = PushNumTeamsCallInst->getArgOperand(3);
4894   Value *ThreadLimitArg = PushNumTeamsCallInst->getArgOperand(4);
4895 
4896   // Get the boolean conversion of if expression
4897   ASSERT_EQ(IfExpr->getNumUses(), 1U);
4898   User *IfExprInst = IfExpr->user_back();
4899   ICmpInst *IfExprCmpInst = dyn_cast<ICmpInst>(IfExprInst);
4900   ASSERT_NE(IfExprCmpInst, nullptr);
4901   EXPECT_EQ(IfExprCmpInst->getPredicate(), ICmpInst::Predicate::ICMP_NE);
4902   EXPECT_EQ(IfExprCmpInst->getOperand(0), IfExpr);
4903   EXPECT_EQ(IfExprCmpInst->getOperand(1), Builder.getInt32(0));
4904 
4905   // Check the lower_bound
4906   ASSERT_NE(NumTeamsLowerArg, nullptr);
4907   SelectInst *NumTeamsLowerSelectInst = dyn_cast<SelectInst>(NumTeamsLowerArg);
4908   ASSERT_NE(NumTeamsLowerSelectInst, nullptr);
4909   EXPECT_EQ(NumTeamsLowerSelectInst->getCondition(), IfExprCmpInst);
4910   EXPECT_EQ(NumTeamsLowerSelectInst->getTrueValue(), NumTeamsLower);
4911   EXPECT_EQ(NumTeamsLowerSelectInst->getFalseValue(), Builder.getInt32(1));
4912 
4913   // Check the upper_bound
4914   ASSERT_NE(NumTeamsUpperArg, nullptr);
4915   SelectInst *NumTeamsUpperSelectInst = dyn_cast<SelectInst>(NumTeamsUpperArg);
4916   ASSERT_NE(NumTeamsUpperSelectInst, nullptr);
4917   EXPECT_EQ(NumTeamsUpperSelectInst->getCondition(), IfExprCmpInst);
4918   EXPECT_EQ(NumTeamsUpperSelectInst->getTrueValue(), NumTeamsUpper);
4919   EXPECT_EQ(NumTeamsUpperSelectInst->getFalseValue(), Builder.getInt32(1));
4920 
4921   // Check thread_limit
4922   EXPECT_EQ(ThreadLimitArg, ThreadLimit);
4923 }
4924 
/// Returns the instruction of InstTy type in BB that uses the value V, or null
/// if there is none or the choice is ambiguous.
///
/// Users of V are visited in use-list order. Users that are not InstTy, or that
/// live outside BB, are ignored. The first matching user becomes the
/// candidate. Each later match is handled as follows:
///   - a StoreInst whose value operand is V is skipped;
///   - any other StoreInst replaces the current candidate;
///   - any other instruction makes the function return null.
///
/// NOTE(review): for InstTy = StoreInst the result depends on use-list order.
/// The first matching store is kept even when it stores V as a value. A later
/// store that does not store V as its value replaces the candidate instead of
/// producing null. isValueReducedToFuncArg relies on the first-match behaviour
/// (it looks for a store of V as a value), so confirm all callers before
/// tightening this.
template <typename InstTy>
static InstTy *findSingleUserInBlock(Value *V, BasicBlock *BB) {
  InstTy *Result = nullptr;
  for (User *U : V->users()) {
    auto *Inst = dyn_cast<InstTy>(U);
    if (!Inst || Inst->getParent() != BB)
      continue;
    if (Result) {
      // A candidate already exists: stores of V as a value are ignored, other
      // stores fall through and overwrite Result, anything else is ambiguous.
      if (auto *SI = dyn_cast<StoreInst>(Inst)) {
        if (V == SI->getValueOperand())
          continue;
      } else {
        return nullptr;
      }
    }
    Result = Inst;
  }
  return Result;
}
4946 
4947 /// Returns true if BB contains a simple binary reduction that loads a value
4948 /// from Accum, performs some binary operation with it, and stores it back to
4949 /// Accum.
4950 static bool isSimpleBinaryReduction(Value *Accum, BasicBlock *BB,
4951                                     Instruction::BinaryOps *OpCode = nullptr) {
4952   StoreInst *Store = findSingleUserInBlock<StoreInst>(Accum, BB);
4953   if (!Store)
4954     return false;
4955   auto *Stored = dyn_cast<BinaryOperator>(Store->getOperand(0));
4956   if (!Stored)
4957     return false;
4958   if (OpCode && *OpCode != Stored->getOpcode())
4959     return false;
4960   auto *Load = dyn_cast<LoadInst>(Stored->getOperand(0));
4961   return Load && Load->getOperand(0) == Accum;
4962 }
4963 
4964 /// Returns true if BB contains a binary reduction that reduces V using a binary
4965 /// operator into an accumulator that is a function argument.
4966 static bool isValueReducedToFuncArg(Value *V, BasicBlock *BB) {
4967   auto *ReductionOp = findSingleUserInBlock<BinaryOperator>(V, BB);
4968   if (!ReductionOp)
4969     return false;
4970 
4971   auto *GlobalLoad = dyn_cast<LoadInst>(ReductionOp->getOperand(0));
4972   if (!GlobalLoad)
4973     return false;
4974 
4975   auto *Store = findSingleUserInBlock<StoreInst>(ReductionOp, BB);
4976   if (!Store)
4977     return false;
4978 
4979   return Store->getPointerOperand() == GlobalLoad->getPointerOperand() &&
4980          isa<Argument>(findAggregateFromValue(GlobalLoad->getPointerOperand()));
4981 }
4982 
4983 /// Finds among users of Ptr a pair of GEP instructions with indices [0, 0] and
4984 /// [0, 1], respectively, and assigns results of these instructions to Zero and
4985 /// One. Returns true on success, false on failure or if such instructions are
4986 /// not unique among the users of Ptr.
4987 static bool findGEPZeroOne(Value *Ptr, Value *&Zero, Value *&One) {
4988   Zero = nullptr;
4989   One = nullptr;
4990   for (User *U : Ptr->users()) {
4991     if (auto *GEP = dyn_cast<GetElementPtrInst>(U)) {
4992       if (GEP->getNumIndices() != 2)
4993         continue;
4994       auto *FirstIdx = dyn_cast<ConstantInt>(GEP->getOperand(1));
4995       auto *SecondIdx = dyn_cast<ConstantInt>(GEP->getOperand(2));
4996       EXPECT_NE(FirstIdx, nullptr);
4997       EXPECT_NE(SecondIdx, nullptr);
4998 
4999       EXPECT_TRUE(FirstIdx->isZero());
5000       if (SecondIdx->isZero()) {
5001         if (Zero)
5002           return false;
5003         Zero = GEP;
5004       } else if (SecondIdx->isOne()) {
5005         if (One)
5006           return false;
5007         One = GEP;
5008       } else {
5009         return false;
5010       }
5011     }
5012   }
5013   return Zero != nullptr && One != nullptr;
5014 }
5015 
5016 static OpenMPIRBuilder::InsertPointTy
5017 sumReduction(OpenMPIRBuilder::InsertPointTy IP, Value *LHS, Value *RHS,
5018              Value *&Result) {
5019   IRBuilder<> Builder(IP.getBlock(), IP.getPoint());
5020   Result = Builder.CreateFAdd(LHS, RHS, "red.add");
5021   return Builder.saveIP();
5022 }
5023 
5024 static OpenMPIRBuilder::InsertPointTy
5025 sumAtomicReduction(OpenMPIRBuilder::InsertPointTy IP, Type *Ty, Value *LHS,
5026                    Value *RHS) {
5027   IRBuilder<> Builder(IP.getBlock(), IP.getPoint());
5028   Value *Partial = Builder.CreateLoad(Ty, RHS, "red.partial");
5029   Builder.CreateAtomicRMW(AtomicRMWInst::FAdd, LHS, Partial, std::nullopt,
5030                           AtomicOrdering::Monotonic);
5031   return Builder.saveIP();
5032 }
5033 
5034 static OpenMPIRBuilder::InsertPointTy
5035 xorReduction(OpenMPIRBuilder::InsertPointTy IP, Value *LHS, Value *RHS,
5036              Value *&Result) {
5037   IRBuilder<> Builder(IP.getBlock(), IP.getPoint());
5038   Result = Builder.CreateXor(LHS, RHS, "red.xor");
5039   return Builder.saveIP();
5040 }
5041 
5042 static OpenMPIRBuilder::InsertPointTy
5043 xorAtomicReduction(OpenMPIRBuilder::InsertPointTy IP, Type *Ty, Value *LHS,
5044                    Value *RHS) {
5045   IRBuilder<> Builder(IP.getBlock(), IP.getPoint());
5046   Value *Partial = Builder.CreateLoad(Ty, RHS, "red.partial");
5047   Builder.CreateAtomicRMW(AtomicRMWInst::Xor, LHS, Partial, std::nullopt,
5048                           AtomicOrdering::Monotonic);
5049   return Builder.saveIP();
5050 }
5051 
5052 TEST_F(OpenMPIRBuilderTest, CreateReductions) {
5053   using InsertPointTy = OpenMPIRBuilder::InsertPointTy;
5054   OpenMPIRBuilder OMPBuilder(*M);
5055   OMPBuilder.Config.IsTargetDevice = false;
5056   OMPBuilder.initialize();
5057   F->setName("func");
5058   IRBuilder<> Builder(BB);
5059 
5060   BasicBlock *EnterBB = BasicBlock::Create(Ctx, "parallel.enter", F);
5061   Builder.CreateBr(EnterBB);
5062   Builder.SetInsertPoint(EnterBB);
5063   OpenMPIRBuilder::LocationDescription Loc({Builder.saveIP(), DL});
5064 
5065   // Create variables to be reduced.
5066   InsertPointTy OuterAllocaIP(&F->getEntryBlock(),
5067                               F->getEntryBlock().getFirstInsertionPt());
5068   Type *SumType = Builder.getFloatTy();
5069   Type *XorType = Builder.getInt32Ty();
5070   Value *SumReduced;
5071   Value *XorReduced;
5072   {
5073     IRBuilderBase::InsertPointGuard Guard(Builder);
5074     Builder.restoreIP(OuterAllocaIP);
5075     SumReduced = Builder.CreateAlloca(SumType);
5076     XorReduced = Builder.CreateAlloca(XorType);
5077   }
5078 
5079   // Store initial values of reductions into global variables.
5080   Builder.CreateStore(ConstantFP::get(Builder.getFloatTy(), 0.0), SumReduced);
5081   Builder.CreateStore(Builder.getInt32(1), XorReduced);
5082 
5083   // The loop body computes two reductions:
5084   //   sum of (float) thread-id;
5085   //   xor of thread-id;
5086   // and store the result in global variables.
5087   InsertPointTy BodyIP, BodyAllocaIP;
5088   auto BodyGenCB = [&](InsertPointTy InnerAllocaIP, InsertPointTy CodeGenIP) {
5089     IRBuilderBase::InsertPointGuard Guard(Builder);
5090     Builder.restoreIP(CodeGenIP);
5091 
5092     uint32_t StrSize;
5093     Constant *SrcLocStr = OMPBuilder.getOrCreateSrcLocStr(Loc, StrSize);
5094     Value *Ident = OMPBuilder.getOrCreateIdent(SrcLocStr, StrSize);
5095     Value *TID = OMPBuilder.getOrCreateThreadID(Ident);
5096     Value *SumLocal =
5097         Builder.CreateUIToFP(TID, Builder.getFloatTy(), "sum.local");
5098     Value *SumPartial = Builder.CreateLoad(SumType, SumReduced, "sum.partial");
5099     Value *XorPartial = Builder.CreateLoad(XorType, XorReduced, "xor.partial");
5100     Value *Sum = Builder.CreateFAdd(SumPartial, SumLocal, "sum");
5101     Value *Xor = Builder.CreateXor(XorPartial, TID, "xor");
5102     Builder.CreateStore(Sum, SumReduced);
5103     Builder.CreateStore(Xor, XorReduced);
5104 
5105     BodyIP = Builder.saveIP();
5106     BodyAllocaIP = InnerAllocaIP;
5107     return Error::success();
5108   };
5109 
5110   // Privatization for reduction creates local copies of reduction variables and
5111   // initializes them to reduction-neutral values.
5112   Value *SumPrivatized;
5113   Value *XorPrivatized;
5114   auto PrivCB = [&](InsertPointTy InnerAllocaIP, InsertPointTy CodeGenIP,
5115                     Value &Original, Value &Inner, Value *&ReplVal) {
5116     IRBuilderBase::InsertPointGuard Guard(Builder);
5117     Builder.restoreIP(InnerAllocaIP);
5118     if (&Original == SumReduced) {
5119       SumPrivatized = Builder.CreateAlloca(Builder.getFloatTy());
5120       ReplVal = SumPrivatized;
5121     } else if (&Original == XorReduced) {
5122       XorPrivatized = Builder.CreateAlloca(Builder.getInt32Ty());
5123       ReplVal = XorPrivatized;
5124     } else {
5125       ReplVal = &Inner;
5126       return CodeGenIP;
5127     }
5128 
5129     Builder.restoreIP(CodeGenIP);
5130     if (&Original == SumReduced)
5131       Builder.CreateStore(ConstantFP::get(Builder.getFloatTy(), 0.0),
5132                           SumPrivatized);
5133     else if (&Original == XorReduced)
5134       Builder.CreateStore(Builder.getInt32(0), XorPrivatized);
5135 
5136     return Builder.saveIP();
5137   };
5138 
5139   // Do nothing in finalization.
5140   auto FiniCB = [&](InsertPointTy CodeGenIP) { return Error::success(); };
5141 
5142   OpenMPIRBuilder::InsertPointOrErrorTy AfterIP =
5143       OMPBuilder.createParallel(Loc, OuterAllocaIP, BodyGenCB, PrivCB, FiniCB,
5144                                 /* IfCondition */ nullptr,
5145                                 /* NumThreads */ nullptr, OMP_PROC_BIND_default,
5146                                 /* IsCancellable */ false);
5147   assert(AfterIP && "unexpected error");
5148   Builder.restoreIP(*AfterIP);
5149 
5150   OpenMPIRBuilder::ReductionInfo ReductionInfos[] = {
5151       {SumType, SumReduced, SumPrivatized,
5152        /*EvaluationKind=*/OpenMPIRBuilder::EvalKind::Scalar, sumReduction,
5153        /*ReductionGenClang=*/nullptr, sumAtomicReduction},
5154       {XorType, XorReduced, XorPrivatized,
5155        /*EvaluationKind=*/OpenMPIRBuilder::EvalKind::Scalar, xorReduction,
5156        /*ReductionGenClang=*/nullptr, xorAtomicReduction}};
5157   OMPBuilder.Config.setIsGPU(false);
5158 
5159   bool ReduceVariableByRef[] = {false, false};
5160 
5161   OpenMPIRBuilder::InsertPointOrErrorTy ReductionsIP =
5162       OMPBuilder.createReductions(BodyIP, BodyAllocaIP, ReductionInfos,
5163                                   ReduceVariableByRef);
5164   assert(ReductionsIP && "unexpected error");
5165 
5166   Builder.restoreIP(*AfterIP);
5167   Builder.CreateRetVoid();
5168 
5169   OMPBuilder.finalize(F);
5170 
5171   // The IR must be valid.
5172   EXPECT_FALSE(verifyModule(*M));
5173 
5174   // Outlining must have happened.
5175   SmallVector<CallInst *> ForkCalls;
5176   findCalls(F, omp::RuntimeFunction::OMPRTL___kmpc_fork_call, OMPBuilder,
5177             ForkCalls);
5178   ASSERT_EQ(ForkCalls.size(), 1u);
5179   Value *CalleeVal = ForkCalls[0]->getOperand(2);
5180   Function *Outlined = dyn_cast<Function>(CalleeVal);
5181   EXPECT_NE(Outlined, nullptr);
5182 
5183   // Check that the lock variable was created with the expected name.
5184   GlobalVariable *LockVar =
5185       M->getGlobalVariable(".gomp_critical_user_.reduction.var");
5186   EXPECT_NE(LockVar, nullptr);
5187 
5188   // Find the allocation of a local array that will be used to call the runtime
5189   // reduciton function.
5190   BasicBlock &AllocBlock = Outlined->getEntryBlock();
5191   Value *LocalArray = nullptr;
5192   for (Instruction &I : AllocBlock) {
5193     if (AllocaInst *Alloc = dyn_cast<AllocaInst>(&I)) {
5194       if (!Alloc->getAllocatedType()->isArrayTy() ||
5195           !Alloc->getAllocatedType()->getArrayElementType()->isPointerTy())
5196         continue;
5197       LocalArray = Alloc;
5198       break;
5199     }
5200   }
5201   ASSERT_NE(LocalArray, nullptr);
5202 
5203   // Find the call to the runtime reduction function.
5204   BasicBlock *BB = AllocBlock.getUniqueSuccessor();
5205   Value *LocalArrayPtr = nullptr;
5206   Value *ReductionFnVal = nullptr;
5207   Value *SwitchArg = nullptr;
5208   for (Instruction &I : *BB) {
5209     if (CallInst *Call = dyn_cast<CallInst>(&I)) {
5210       if (Call->getCalledFunction() !=
5211           OMPBuilder.getOrCreateRuntimeFunctionPtr(
5212               RuntimeFunction::OMPRTL___kmpc_reduce))
5213         continue;
5214       LocalArrayPtr = Call->getOperand(4);
5215       ReductionFnVal = Call->getOperand(5);
5216       SwitchArg = Call;
5217       break;
5218     }
5219   }
5220 
5221   // Check that the local array is passed to the function.
5222   ASSERT_NE(LocalArrayPtr, nullptr);
5223   EXPECT_EQ(LocalArrayPtr, LocalArray);
5224 
5225   // Find the GEP instructions preceding stores to the local array.
5226   Value *FirstArrayElemPtr = nullptr;
5227   Value *SecondArrayElemPtr = nullptr;
5228   EXPECT_EQ(LocalArray->getNumUses(), 3u);
5229   ASSERT_TRUE(
5230       findGEPZeroOne(LocalArray, FirstArrayElemPtr, SecondArrayElemPtr));
5231 
5232   // Check that the values stored into the local array are privatized reduction
5233   // variables.
5234   auto *FirstPrivatized = dyn_cast_or_null<AllocaInst>(
5235       findStoredValue<GetElementPtrInst>(FirstArrayElemPtr));
5236   auto *SecondPrivatized = dyn_cast_or_null<AllocaInst>(
5237       findStoredValue<GetElementPtrInst>(SecondArrayElemPtr));
5238   ASSERT_NE(FirstPrivatized, nullptr);
5239   ASSERT_NE(SecondPrivatized, nullptr);
5240   ASSERT_TRUE(isa<Instruction>(FirstArrayElemPtr));
5241   EXPECT_TRUE(isSimpleBinaryReduction(
5242       FirstPrivatized, cast<Instruction>(FirstArrayElemPtr)->getParent()));
5243   EXPECT_TRUE(isSimpleBinaryReduction(
5244       SecondPrivatized, cast<Instruction>(FirstArrayElemPtr)->getParent()));
5245 
5246   // Check that the result of the runtime reduction call is used for further
5247   // dispatch.
5248   ASSERT_EQ(SwitchArg->getNumUses(), 1u);
5249   SwitchInst *Switch = dyn_cast<SwitchInst>(*SwitchArg->user_begin());
5250   ASSERT_NE(Switch, nullptr);
5251   EXPECT_EQ(Switch->getNumSuccessors(), 3u);
5252   BasicBlock *NonAtomicBB = Switch->case_begin()->getCaseSuccessor();
5253   BasicBlock *AtomicBB = std::next(Switch->case_begin())->getCaseSuccessor();
5254 
5255   // Non-atomic block contains reductions to the global reduction variable,
5256   // which is passed into the outlined function as an argument.
5257   Value *FirstLoad =
5258       findSingleUserInBlock<LoadInst>(FirstPrivatized, NonAtomicBB);
5259   Value *SecondLoad =
5260       findSingleUserInBlock<LoadInst>(SecondPrivatized, NonAtomicBB);
5261   EXPECT_TRUE(isValueReducedToFuncArg(FirstLoad, NonAtomicBB));
5262   EXPECT_TRUE(isValueReducedToFuncArg(SecondLoad, NonAtomicBB));
5263 
5264   // Atomic block also constains reductions to the global reduction variable.
5265   FirstLoad = findSingleUserInBlock<LoadInst>(FirstPrivatized, AtomicBB);
5266   SecondLoad = findSingleUserInBlock<LoadInst>(SecondPrivatized, AtomicBB);
5267   auto *FirstAtomic = findSingleUserInBlock<AtomicRMWInst>(FirstLoad, AtomicBB);
5268   auto *SecondAtomic =
5269       findSingleUserInBlock<AtomicRMWInst>(SecondLoad, AtomicBB);
5270   ASSERT_NE(FirstAtomic, nullptr);
5271   Value *AtomicStorePointer = FirstAtomic->getPointerOperand();
5272   EXPECT_TRUE(isa<Argument>(findAggregateFromValue(AtomicStorePointer)));
5273   ASSERT_NE(SecondAtomic, nullptr);
5274   AtomicStorePointer = SecondAtomic->getPointerOperand();
5275   EXPECT_TRUE(isa<Argument>(findAggregateFromValue(AtomicStorePointer)));
5276 
5277   // Check that the separate reduction function also performs (non-atomic)
5278   // reductions after extracting reduction variables from its arguments.
5279   Function *ReductionFn = cast<Function>(ReductionFnVal);
5280   BasicBlock *FnReductionBB = &ReductionFn->getEntryBlock();
5281   Value *FirstLHSPtr;
5282   Value *SecondLHSPtr;
5283   ASSERT_TRUE(
5284       findGEPZeroOne(ReductionFn->getArg(0), FirstLHSPtr, SecondLHSPtr));
5285   Value *Opaque = findSingleUserInBlock<LoadInst>(FirstLHSPtr, FnReductionBB);
5286   ASSERT_NE(Opaque, nullptr);
5287   EXPECT_TRUE(isSimpleBinaryReduction(Opaque, FnReductionBB));
5288   Opaque = findSingleUserInBlock<LoadInst>(SecondLHSPtr, FnReductionBB);
5289   ASSERT_NE(Opaque, nullptr);
5290   EXPECT_TRUE(isSimpleBinaryReduction(Opaque, FnReductionBB));
5291 
5292   Value *FirstRHS;
5293   Value *SecondRHS;
5294   EXPECT_TRUE(findGEPZeroOne(ReductionFn->getArg(1), FirstRHS, SecondRHS));
5295 }
5296 
5297 TEST_F(OpenMPIRBuilderTest, CreateTwoReductions) {
5298   using InsertPointTy = OpenMPIRBuilder::InsertPointTy;
5299   OpenMPIRBuilder OMPBuilder(*M);
5300   OMPBuilder.Config.IsTargetDevice = false;
5301   OMPBuilder.initialize();
5302   F->setName("func");
5303   IRBuilder<> Builder(BB);
5304 
5305   BasicBlock *EnterBB = BasicBlock::Create(Ctx, "parallel.enter", F);
5306   Builder.CreateBr(EnterBB);
5307   Builder.SetInsertPoint(EnterBB);
5308   OpenMPIRBuilder::LocationDescription Loc({Builder.saveIP(), DL});
5309 
5310   // Create variables to be reduced.
5311   InsertPointTy OuterAllocaIP(&F->getEntryBlock(),
5312                               F->getEntryBlock().getFirstInsertionPt());
5313   Type *SumType = Builder.getFloatTy();
5314   Type *XorType = Builder.getInt32Ty();
5315   Value *SumReduced;
5316   Value *XorReduced;
5317   {
5318     IRBuilderBase::InsertPointGuard Guard(Builder);
5319     Builder.restoreIP(OuterAllocaIP);
5320     SumReduced = Builder.CreateAlloca(SumType);
5321     XorReduced = Builder.CreateAlloca(XorType);
5322   }
5323 
5324   // Store initial values of reductions into global variables.
5325   Builder.CreateStore(ConstantFP::get(Builder.getFloatTy(), 0.0), SumReduced);
5326   Builder.CreateStore(Builder.getInt32(1), XorReduced);
5327 
5328   InsertPointTy FirstBodyIP, FirstBodyAllocaIP;
5329   auto FirstBodyGenCB = [&](InsertPointTy InnerAllocaIP,
5330                             InsertPointTy CodeGenIP) {
5331     IRBuilderBase::InsertPointGuard Guard(Builder);
5332     Builder.restoreIP(CodeGenIP);
5333 
5334     uint32_t StrSize;
5335     Constant *SrcLocStr = OMPBuilder.getOrCreateSrcLocStr(Loc, StrSize);
5336     Value *Ident = OMPBuilder.getOrCreateIdent(SrcLocStr, StrSize);
5337     Value *TID = OMPBuilder.getOrCreateThreadID(Ident);
5338     Value *SumLocal =
5339         Builder.CreateUIToFP(TID, Builder.getFloatTy(), "sum.local");
5340     Value *SumPartial = Builder.CreateLoad(SumType, SumReduced, "sum.partial");
5341     Value *Sum = Builder.CreateFAdd(SumPartial, SumLocal, "sum");
5342     Builder.CreateStore(Sum, SumReduced);
5343 
5344     FirstBodyIP = Builder.saveIP();
5345     FirstBodyAllocaIP = InnerAllocaIP;
5346     return Error::success();
5347   };
5348 
5349   InsertPointTy SecondBodyIP, SecondBodyAllocaIP;
5350   auto SecondBodyGenCB = [&](InsertPointTy InnerAllocaIP,
5351                              InsertPointTy CodeGenIP) {
5352     IRBuilderBase::InsertPointGuard Guard(Builder);
5353     Builder.restoreIP(CodeGenIP);
5354 
5355     uint32_t StrSize;
5356     Constant *SrcLocStr = OMPBuilder.getOrCreateSrcLocStr(Loc, StrSize);
5357     Value *Ident = OMPBuilder.getOrCreateIdent(SrcLocStr, StrSize);
5358     Value *TID = OMPBuilder.getOrCreateThreadID(Ident);
5359     Value *XorPartial = Builder.CreateLoad(XorType, XorReduced, "xor.partial");
5360     Value *Xor = Builder.CreateXor(XorPartial, TID, "xor");
5361     Builder.CreateStore(Xor, XorReduced);
5362 
5363     SecondBodyIP = Builder.saveIP();
5364     SecondBodyAllocaIP = InnerAllocaIP;
5365     return Error::success();
5366   };
5367 
5368   // Privatization for reduction creates local copies of reduction variables and
5369   // initializes them to reduction-neutral values. The same privatization
5370   // callback is used for both loops, with dispatch based on the value being
5371   // privatized.
5372   Value *SumPrivatized;
5373   Value *XorPrivatized;
5374   auto PrivCB = [&](InsertPointTy InnerAllocaIP, InsertPointTy CodeGenIP,
5375                     Value &Original, Value &Inner, Value *&ReplVal) {
5376     IRBuilderBase::InsertPointGuard Guard(Builder);
5377     Builder.restoreIP(InnerAllocaIP);
5378     if (&Original == SumReduced) {
5379       SumPrivatized = Builder.CreateAlloca(Builder.getFloatTy());
5380       ReplVal = SumPrivatized;
5381     } else if (&Original == XorReduced) {
5382       XorPrivatized = Builder.CreateAlloca(Builder.getInt32Ty());
5383       ReplVal = XorPrivatized;
5384     } else {
5385       ReplVal = &Inner;
5386       return CodeGenIP;
5387     }
5388 
5389     Builder.restoreIP(CodeGenIP);
5390     if (&Original == SumReduced)
5391       Builder.CreateStore(ConstantFP::get(Builder.getFloatTy(), 0.0),
5392                           SumPrivatized);
5393     else if (&Original == XorReduced)
5394       Builder.CreateStore(Builder.getInt32(0), XorPrivatized);
5395 
5396     return Builder.saveIP();
5397   };
5398 
5399   // Do nothing in finalization.
5400   auto FiniCB = [&](InsertPointTy CodeGenIP) { return Error::success(); };
5401 
5402   OpenMPIRBuilder::InsertPointOrErrorTy AfterIP1 =
5403       OMPBuilder.createParallel(Loc, OuterAllocaIP, FirstBodyGenCB, PrivCB,
5404                                 FiniCB, /* IfCondition */ nullptr,
5405                                 /* NumThreads */ nullptr, OMP_PROC_BIND_default,
5406                                 /* IsCancellable */ false);
5407   assert(AfterIP1 && "unexpected error");
5408   Builder.restoreIP(*AfterIP1);
5409   OpenMPIRBuilder::InsertPointOrErrorTy AfterIP2 = OMPBuilder.createParallel(
5410       {Builder.saveIP(), DL}, OuterAllocaIP, SecondBodyGenCB, PrivCB, FiniCB,
5411       /* IfCondition */ nullptr,
5412       /* NumThreads */ nullptr, OMP_PROC_BIND_default,
5413       /* IsCancellable */ false);
5414   assert(AfterIP2 && "unexpected error");
5415   Builder.restoreIP(*AfterIP2);
5416 
5417   OMPBuilder.Config.setIsGPU(false);
5418   bool ReduceVariableByRef[] = {false};
5419 
5420   OpenMPIRBuilder::InsertPointOrErrorTy ReductionsIP1 =
5421       OMPBuilder.createReductions(
5422           FirstBodyIP, FirstBodyAllocaIP,
5423           {{SumType, SumReduced, SumPrivatized,
5424             /*EvaluationKind=*/OpenMPIRBuilder::EvalKind::Scalar, sumReduction,
5425             /*ReductionGenClang=*/nullptr, sumAtomicReduction}},
5426           ReduceVariableByRef);
5427   assert(ReductionsIP1 && "unexpected error");
5428   OpenMPIRBuilder::InsertPointOrErrorTy ReductionsIP2 =
5429       OMPBuilder.createReductions(
5430           SecondBodyIP, SecondBodyAllocaIP,
5431           {{XorType, XorReduced, XorPrivatized,
5432             /*EvaluationKind=*/OpenMPIRBuilder::EvalKind::Scalar, xorReduction,
5433             /*ReductionGenClang=*/nullptr, xorAtomicReduction}},
5434           ReduceVariableByRef);
5435   assert(ReductionsIP2 && "unexpected error");
5436 
5437   Builder.restoreIP(*AfterIP2);
5438   Builder.CreateRetVoid();
5439 
5440   OMPBuilder.finalize(F);
5441 
5442   // The IR must be valid.
5443   EXPECT_FALSE(verifyModule(*M));
5444 
5445   // Two different outlined functions must have been created.
5446   SmallVector<CallInst *> ForkCalls;
5447   findCalls(F, omp::RuntimeFunction::OMPRTL___kmpc_fork_call, OMPBuilder,
5448             ForkCalls);
5449   ASSERT_EQ(ForkCalls.size(), 2u);
5450   Value *CalleeVal = ForkCalls[0]->getOperand(2);
5451   Function *FirstCallee = cast<Function>(CalleeVal);
5452   CalleeVal = ForkCalls[1]->getOperand(2);
5453   Function *SecondCallee = cast<Function>(CalleeVal);
5454   EXPECT_NE(FirstCallee, SecondCallee);
5455 
5456   // Two different reduction functions must have been created.
5457   SmallVector<CallInst *> ReduceCalls;
5458   findCalls(FirstCallee, omp::RuntimeFunction::OMPRTL___kmpc_reduce, OMPBuilder,
5459             ReduceCalls);
5460   ASSERT_EQ(ReduceCalls.size(), 1u);
5461   auto *AddReduction = cast<Function>(ReduceCalls[0]->getOperand(5));
5462   ReduceCalls.clear();
5463   findCalls(SecondCallee, omp::RuntimeFunction::OMPRTL___kmpc_reduce,
5464             OMPBuilder, ReduceCalls);
5465   auto *XorReduction = cast<Function>(ReduceCalls[0]->getOperand(5));
5466   EXPECT_NE(AddReduction, XorReduction);
5467 
5468   // Each reduction function does its own kind of reduction.
5469   BasicBlock *FnReductionBB = &AddReduction->getEntryBlock();
5470   Value *FirstLHSPtr = findSingleUserInBlock<GetElementPtrInst>(
5471       AddReduction->getArg(0), FnReductionBB);
5472   ASSERT_NE(FirstLHSPtr, nullptr);
5473   Value *Opaque = findSingleUserInBlock<LoadInst>(FirstLHSPtr, FnReductionBB);
5474   ASSERT_NE(Opaque, nullptr);
5475   Instruction::BinaryOps Opcode = Instruction::FAdd;
5476   EXPECT_TRUE(isSimpleBinaryReduction(Opaque, FnReductionBB, &Opcode));
5477 
5478   FnReductionBB = &XorReduction->getEntryBlock();
5479   Value *SecondLHSPtr = findSingleUserInBlock<GetElementPtrInst>(
5480       XorReduction->getArg(0), FnReductionBB);
5481   ASSERT_NE(FirstLHSPtr, nullptr);
5482   Opaque = findSingleUserInBlock<LoadInst>(SecondLHSPtr, FnReductionBB);
5483   ASSERT_NE(Opaque, nullptr);
5484   Opcode = Instruction::Xor;
5485   EXPECT_TRUE(isSimpleBinaryReduction(Opaque, FnReductionBB, &Opcode));
5486 }
5487 
5488 TEST_F(OpenMPIRBuilderTest, CreateSectionsSimple) {
5489   using InsertPointTy = OpenMPIRBuilder::InsertPointTy;
5490   using BodyGenCallbackTy = llvm::OpenMPIRBuilder::StorableBodyGenCallbackTy;
5491   OpenMPIRBuilder OMPBuilder(*M);
5492   OMPBuilder.initialize();
5493   F->setName("func");
5494   IRBuilder<> Builder(BB);
5495 
5496   BasicBlock *EnterBB = BasicBlock::Create(Ctx, "sections.enter", F);
5497   Builder.CreateBr(EnterBB);
5498   Builder.SetInsertPoint(EnterBB);
5499   OpenMPIRBuilder::LocationDescription Loc({Builder.saveIP(), DL});
5500 
5501   llvm::SmallVector<BodyGenCallbackTy, 4> SectionCBVector;
5502   llvm::SmallVector<BasicBlock *, 4> CaseBBs;
5503 
5504   auto FiniCB = [&](InsertPointTy IP) { return Error::success(); };
5505   auto SectionCB = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP) {
5506     return Error::success();
5507   };
5508   SectionCBVector.push_back(SectionCB);
5509 
5510   auto PrivCB = [](InsertPointTy AllocaIP, InsertPointTy CodeGenIP,
5511                    llvm::Value &, llvm::Value &Val,
5512                    llvm::Value *&ReplVal) { return CodeGenIP; };
5513   IRBuilder<>::InsertPoint AllocaIP(&F->getEntryBlock(),
5514                                     F->getEntryBlock().getFirstInsertionPt());
5515   OpenMPIRBuilder::InsertPointOrErrorTy AfterIP = OMPBuilder.createSections(
5516       Loc, AllocaIP, SectionCBVector, PrivCB, FiniCB, false, false);
5517   assert(AfterIP && "unexpected error");
5518   Builder.restoreIP(*AfterIP);
5519   Builder.CreateRetVoid(); // Required at the end of the function
5520   EXPECT_NE(F->getEntryBlock().getTerminator(), nullptr);
5521   EXPECT_FALSE(verifyModule(*M, &errs()));
5522 }
5523 
5524 TEST_F(OpenMPIRBuilderTest, CreateSections) {
5525   using InsertPointTy = OpenMPIRBuilder::InsertPointTy;
5526   using BodyGenCallbackTy = llvm::OpenMPIRBuilder::StorableBodyGenCallbackTy;
5527   OpenMPIRBuilder OMPBuilder(*M);
5528   OMPBuilder.initialize();
5529   F->setName("func");
5530   IRBuilder<> Builder(BB);
5531 
5532   OpenMPIRBuilder::LocationDescription Loc({Builder.saveIP(), DL});
5533   llvm::SmallVector<BodyGenCallbackTy, 4> SectionCBVector;
5534   llvm::SmallVector<BasicBlock *, 4> CaseBBs;
5535 
5536   BasicBlock *SwitchBB = nullptr;
5537   AllocaInst *PrivAI = nullptr;
5538   SwitchInst *Switch = nullptr;
5539 
5540   unsigned NumBodiesGenerated = 0;
5541   unsigned NumFiniCBCalls = 0;
5542   PrivAI = Builder.CreateAlloca(F->arg_begin()->getType());
5543 
5544   auto FiniCB = [&](InsertPointTy IP) {
5545     ++NumFiniCBCalls;
5546     BasicBlock *IPBB = IP.getBlock();
5547     EXPECT_NE(IPBB->end(), IP.getPoint());
5548   };
5549 
5550   auto SectionCB = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP) {
5551     ++NumBodiesGenerated;
5552     CaseBBs.push_back(CodeGenIP.getBlock());
5553     SwitchBB = CodeGenIP.getBlock()->getSinglePredecessor();
5554     Builder.restoreIP(CodeGenIP);
5555     Builder.CreateStore(F->arg_begin(), PrivAI);
5556     Value *PrivLoad =
5557         Builder.CreateLoad(F->arg_begin()->getType(), PrivAI, "local.alloca");
5558     Builder.CreateICmpNE(F->arg_begin(), PrivLoad);
5559     return Error::success();
5560   };
5561   auto PrivCB = [](InsertPointTy AllocaIP, InsertPointTy CodeGenIP,
5562                    llvm::Value &, llvm::Value &Val, llvm::Value *&ReplVal) {
5563     // TODO: Privatization not implemented yet
5564     return CodeGenIP;
5565   };
5566 
5567   SectionCBVector.push_back(SectionCB);
5568   SectionCBVector.push_back(SectionCB);
5569 
5570   IRBuilder<>::InsertPoint AllocaIP(&F->getEntryBlock(),
5571                                     F->getEntryBlock().getFirstInsertionPt());
5572   OpenMPIRBuilder::InsertPointOrErrorTy AfterIP =
5573       OMPBuilder.createSections(Loc, AllocaIP, SectionCBVector, PrivCB,
5574                                 FINICB_WRAPPER(FiniCB), false, false);
5575   assert(AfterIP && "unexpected error");
5576   Builder.restoreIP(*AfterIP);
5577   Builder.CreateRetVoid(); // Required at the end of the function
5578 
5579   // Switch BB's predecessor is loop condition BB, whose successor at index 1 is
5580   // loop's exit BB
5581   BasicBlock *ForExitBB =
5582       SwitchBB->getSinglePredecessor()->getTerminator()->getSuccessor(1);
5583   EXPECT_NE(ForExitBB, nullptr);
5584 
5585   EXPECT_NE(PrivAI, nullptr);
5586   Function *OutlinedFn = PrivAI->getFunction();
5587   EXPECT_EQ(F, OutlinedFn);
5588   EXPECT_FALSE(verifyModule(*M, &errs()));
5589   EXPECT_EQ(OutlinedFn->arg_size(), 1U);
5590 
5591   BasicBlock *LoopPreheaderBB =
5592       OutlinedFn->getEntryBlock().getSingleSuccessor();
5593   // loop variables are 5 - lower bound, upper bound, stride, islastiter, and
5594   // iterator/counter
5595   bool FoundForInit = false;
5596   for (Instruction &Inst : *LoopPreheaderBB) {
5597     if (isa<CallInst>(Inst)) {
5598       if (cast<CallInst>(&Inst)->getCalledFunction()->getName() ==
5599           "__kmpc_for_static_init_4u") {
5600         FoundForInit = true;
5601       }
5602     }
5603   }
5604   EXPECT_EQ(FoundForInit, true);
5605 
5606   bool FoundForExit = false;
5607   bool FoundBarrier = false;
5608   for (Instruction &Inst : *ForExitBB) {
5609     if (isa<CallInst>(Inst)) {
5610       if (cast<CallInst>(&Inst)->getCalledFunction()->getName() ==
5611           "__kmpc_for_static_fini") {
5612         FoundForExit = true;
5613       }
5614       if (cast<CallInst>(&Inst)->getCalledFunction()->getName() ==
5615           "__kmpc_barrier") {
5616         FoundBarrier = true;
5617       }
5618       if (FoundForExit && FoundBarrier)
5619         break;
5620     }
5621   }
5622   EXPECT_EQ(FoundForExit, true);
5623   EXPECT_EQ(FoundBarrier, true);
5624 
5625   EXPECT_NE(SwitchBB, nullptr);
5626   EXPECT_NE(SwitchBB->getTerminator(), nullptr);
5627   EXPECT_EQ(isa<SwitchInst>(SwitchBB->getTerminator()), true);
5628   Switch = cast<SwitchInst>(SwitchBB->getTerminator());
5629   EXPECT_EQ(Switch->getNumCases(), 2U);
5630 
5631   EXPECT_EQ(CaseBBs.size(), 2U);
5632   for (auto *&CaseBB : CaseBBs) {
5633     EXPECT_EQ(CaseBB->getParent(), OutlinedFn);
5634   }
5635 
5636   ASSERT_EQ(NumBodiesGenerated, 2U);
5637   ASSERT_EQ(NumFiniCBCalls, 1U);
5638   EXPECT_FALSE(verifyModule(*M, &errs()));
5639 }
5640 
5641 TEST_F(OpenMPIRBuilderTest, CreateSectionsNoWait) {
5642   using InsertPointTy = OpenMPIRBuilder::InsertPointTy;
5643   using BodyGenCallbackTy = llvm::OpenMPIRBuilder::StorableBodyGenCallbackTy;
5644   OpenMPIRBuilder OMPBuilder(*M);
5645   OMPBuilder.initialize();
5646   F->setName("func");
5647   IRBuilder<> Builder(BB);
5648 
5649   BasicBlock *EnterBB = BasicBlock::Create(Ctx, "sections.enter", F);
5650   Builder.CreateBr(EnterBB);
5651   Builder.SetInsertPoint(EnterBB);
5652   OpenMPIRBuilder::LocationDescription Loc({Builder.saveIP(), DL});
5653 
5654   IRBuilder<>::InsertPoint AllocaIP(&F->getEntryBlock(),
5655                                     F->getEntryBlock().getFirstInsertionPt());
5656   llvm::SmallVector<BodyGenCallbackTy, 4> SectionCBVector;
5657   auto PrivCB = [](InsertPointTy AllocaIP, InsertPointTy CodeGenIP,
5658                    llvm::Value &, llvm::Value &Val,
5659                    llvm::Value *&ReplVal) { return CodeGenIP; };
5660   auto FiniCB = [&](InsertPointTy IP) { return Error::success(); };
5661 
5662   OpenMPIRBuilder::InsertPointOrErrorTy AfterIP = OMPBuilder.createSections(
5663       Loc, AllocaIP, SectionCBVector, PrivCB, FiniCB, false, true);
5664   assert(AfterIP && "unexpected error");
5665   Builder.restoreIP(*AfterIP);
5666   Builder.CreateRetVoid(); // Required at the end of the function
5667   for (auto &Inst : instructions(*F)) {
5668     EXPECT_FALSE(isa<CallInst>(Inst) &&
5669                  cast<CallInst>(&Inst)->getCalledFunction()->getName() ==
5670                      "__kmpc_barrier" &&
5671                  "call to function __kmpc_barrier found with nowait");
5672   }
5673 }
5674 
5675 TEST_F(OpenMPIRBuilderTest, CreateOffloadMaptypes) {
5676   OpenMPIRBuilder OMPBuilder(*M);
5677   OMPBuilder.initialize();
5678 
5679   IRBuilder<> Builder(BB);
5680 
5681   SmallVector<uint64_t> Mappings = {0, 1};
5682   GlobalVariable *OffloadMaptypesGlobal =
5683       OMPBuilder.createOffloadMaptypes(Mappings, "offload_maptypes");
5684   EXPECT_FALSE(M->global_empty());
5685   EXPECT_EQ(OffloadMaptypesGlobal->getName(), "offload_maptypes");
5686   EXPECT_TRUE(OffloadMaptypesGlobal->isConstant());
5687   EXPECT_TRUE(OffloadMaptypesGlobal->hasGlobalUnnamedAddr());
5688   EXPECT_TRUE(OffloadMaptypesGlobal->hasPrivateLinkage());
5689   EXPECT_TRUE(OffloadMaptypesGlobal->hasInitializer());
5690   Constant *Initializer = OffloadMaptypesGlobal->getInitializer();
5691   EXPECT_TRUE(isa<ConstantDataArray>(Initializer));
5692   ConstantDataArray *MappingInit = dyn_cast<ConstantDataArray>(Initializer);
5693   EXPECT_EQ(MappingInit->getNumElements(), Mappings.size());
5694   EXPECT_TRUE(MappingInit->getType()->getElementType()->isIntegerTy(64));
5695   Constant *CA = ConstantDataArray::get(Builder.getContext(), Mappings);
5696   EXPECT_EQ(MappingInit, CA);
5697 }
5698 
5699 TEST_F(OpenMPIRBuilderTest, CreateOffloadMapnames) {
5700   OpenMPIRBuilder OMPBuilder(*M);
5701   OMPBuilder.initialize();
5702 
5703   IRBuilder<> Builder(BB);
5704 
5705   uint32_t StrSize;
5706   Constant *Cst1 =
5707       OMPBuilder.getOrCreateSrcLocStr("array1", "file1", 2, 5, StrSize);
5708   Constant *Cst2 =
5709       OMPBuilder.getOrCreateSrcLocStr("array2", "file1", 3, 5, StrSize);
5710   SmallVector<llvm::Constant *> Names = {Cst1, Cst2};
5711 
5712   GlobalVariable *OffloadMaptypesGlobal =
5713       OMPBuilder.createOffloadMapnames(Names, "offload_mapnames");
5714   EXPECT_FALSE(M->global_empty());
5715   EXPECT_EQ(OffloadMaptypesGlobal->getName(), "offload_mapnames");
5716   EXPECT_TRUE(OffloadMaptypesGlobal->isConstant());
5717   EXPECT_FALSE(OffloadMaptypesGlobal->hasGlobalUnnamedAddr());
5718   EXPECT_TRUE(OffloadMaptypesGlobal->hasPrivateLinkage());
5719   EXPECT_TRUE(OffloadMaptypesGlobal->hasInitializer());
5720   Constant *Initializer = OffloadMaptypesGlobal->getInitializer();
5721   EXPECT_TRUE(isa<Constant>(Initializer->getOperand(0)->stripPointerCasts()));
5722   EXPECT_TRUE(isa<Constant>(Initializer->getOperand(1)->stripPointerCasts()));
5723 
5724   GlobalVariable *Name1Gbl =
5725       cast<GlobalVariable>(Initializer->getOperand(0)->stripPointerCasts());
5726   EXPECT_TRUE(isa<ConstantDataArray>(Name1Gbl->getInitializer()));
5727   ConstantDataArray *Name1GblCA =
5728       dyn_cast<ConstantDataArray>(Name1Gbl->getInitializer());
5729   EXPECT_EQ(Name1GblCA->getAsCString(), ";file1;array1;2;5;;");
5730 
5731   GlobalVariable *Name2Gbl =
5732       cast<GlobalVariable>(Initializer->getOperand(1)->stripPointerCasts());
5733   EXPECT_TRUE(isa<ConstantDataArray>(Name2Gbl->getInitializer()));
5734   ConstantDataArray *Name2GblCA =
5735       dyn_cast<ConstantDataArray>(Name2Gbl->getInitializer());
5736   EXPECT_EQ(Name2GblCA->getAsCString(), ";file1;array2;3;5;;");
5737 
5738   EXPECT_TRUE(Initializer->getType()->getArrayElementType()->isPointerTy());
5739   EXPECT_EQ(Initializer->getType()->getArrayNumElements(), Names.size());
5740 }
5741 
5742 TEST_F(OpenMPIRBuilderTest, CreateMapperAllocas) {
5743   OpenMPIRBuilder OMPBuilder(*M);
5744   OMPBuilder.initialize();
5745   F->setName("func");
5746   IRBuilder<> Builder(BB);
5747 
5748   OpenMPIRBuilder::LocationDescription Loc({Builder.saveIP(), DL});
5749 
5750   unsigned TotalNbOperand = 2;
5751 
5752   OpenMPIRBuilder::MapperAllocas MapperAllocas;
5753   IRBuilder<>::InsertPoint AllocaIP(&F->getEntryBlock(),
5754                                     F->getEntryBlock().getFirstInsertionPt());
5755   OMPBuilder.createMapperAllocas(Loc, AllocaIP, TotalNbOperand, MapperAllocas);
5756   EXPECT_NE(MapperAllocas.ArgsBase, nullptr);
5757   EXPECT_NE(MapperAllocas.Args, nullptr);
5758   EXPECT_NE(MapperAllocas.ArgSizes, nullptr);
5759   EXPECT_TRUE(MapperAllocas.ArgsBase->getAllocatedType()->isArrayTy());
5760   ArrayType *ArrType =
5761       dyn_cast<ArrayType>(MapperAllocas.ArgsBase->getAllocatedType());
5762   EXPECT_EQ(ArrType->getNumElements(), TotalNbOperand);
5763   EXPECT_TRUE(MapperAllocas.ArgsBase->getAllocatedType()
5764                   ->getArrayElementType()
5765                   ->isPointerTy());
5766 
5767   EXPECT_TRUE(MapperAllocas.Args->getAllocatedType()->isArrayTy());
5768   ArrType = dyn_cast<ArrayType>(MapperAllocas.Args->getAllocatedType());
5769   EXPECT_EQ(ArrType->getNumElements(), TotalNbOperand);
5770   EXPECT_TRUE(MapperAllocas.Args->getAllocatedType()
5771                   ->getArrayElementType()
5772                   ->isPointerTy());
5773 
5774   EXPECT_TRUE(MapperAllocas.ArgSizes->getAllocatedType()->isArrayTy());
5775   ArrType = dyn_cast<ArrayType>(MapperAllocas.ArgSizes->getAllocatedType());
5776   EXPECT_EQ(ArrType->getNumElements(), TotalNbOperand);
5777   EXPECT_TRUE(MapperAllocas.ArgSizes->getAllocatedType()
5778                   ->getArrayElementType()
5779                   ->isIntegerTy(64));
5780 }
5781 
5782 TEST_F(OpenMPIRBuilderTest, EmitMapperCall) {
5783   OpenMPIRBuilder OMPBuilder(*M);
5784   OMPBuilder.initialize();
5785   F->setName("func");
5786   IRBuilder<> Builder(BB);
5787   LLVMContext &Ctx = M->getContext();
5788 
5789   OpenMPIRBuilder::LocationDescription Loc({Builder.saveIP(), DL});
5790 
5791   unsigned TotalNbOperand = 2;
5792 
5793   OpenMPIRBuilder::MapperAllocas MapperAllocas;
5794   IRBuilder<>::InsertPoint AllocaIP(&F->getEntryBlock(),
5795                                     F->getEntryBlock().getFirstInsertionPt());
5796   OMPBuilder.createMapperAllocas(Loc, AllocaIP, TotalNbOperand, MapperAllocas);
5797 
5798   auto *BeginMapperFunc = OMPBuilder.getOrCreateRuntimeFunctionPtr(
5799       omp::OMPRTL___tgt_target_data_begin_mapper);
5800 
5801   SmallVector<uint64_t> Flags = {0, 2};
5802 
5803   uint32_t StrSize;
5804   Constant *SrcLocCst =
5805       OMPBuilder.getOrCreateSrcLocStr("", "file1", 2, 5, StrSize);
5806   Value *SrcLocInfo = OMPBuilder.getOrCreateIdent(SrcLocCst, StrSize);
5807 
5808   Constant *Cst1 =
5809       OMPBuilder.getOrCreateSrcLocStr("array1", "file1", 2, 5, StrSize);
5810   Constant *Cst2 =
5811       OMPBuilder.getOrCreateSrcLocStr("array2", "file1", 3, 5, StrSize);
5812   SmallVector<llvm::Constant *> Names = {Cst1, Cst2};
5813 
5814   GlobalVariable *Maptypes =
5815       OMPBuilder.createOffloadMaptypes(Flags, ".offload_maptypes");
5816   Value *MaptypesArg = Builder.CreateConstInBoundsGEP2_32(
5817       ArrayType::get(Type::getInt64Ty(Ctx), TotalNbOperand), Maptypes,
5818       /*Idx0=*/0, /*Idx1=*/0);
5819 
5820   GlobalVariable *Mapnames =
5821       OMPBuilder.createOffloadMapnames(Names, ".offload_mapnames");
5822   Value *MapnamesArg = Builder.CreateConstInBoundsGEP2_32(
5823       ArrayType::get(PointerType::getUnqual(Ctx), TotalNbOperand), Mapnames,
5824       /*Idx0=*/0, /*Idx1=*/0);
5825 
5826   OMPBuilder.emitMapperCall(Builder.saveIP(), BeginMapperFunc, SrcLocInfo,
5827                             MaptypesArg, MapnamesArg, MapperAllocas, -1,
5828                             TotalNbOperand);
5829 
5830   CallInst *MapperCall = dyn_cast<CallInst>(&BB->back());
5831   EXPECT_NE(MapperCall, nullptr);
5832   EXPECT_EQ(MapperCall->arg_size(), 9U);
5833   EXPECT_EQ(MapperCall->getCalledFunction()->getName(),
5834             "__tgt_target_data_begin_mapper");
5835   EXPECT_EQ(MapperCall->getOperand(0), SrcLocInfo);
5836   EXPECT_TRUE(MapperCall->getOperand(1)->getType()->isIntegerTy(64));
5837   EXPECT_TRUE(MapperCall->getOperand(2)->getType()->isIntegerTy(32));
5838 
5839   EXPECT_EQ(MapperCall->getOperand(6), MaptypesArg);
5840   EXPECT_EQ(MapperCall->getOperand(7), MapnamesArg);
5841   EXPECT_TRUE(MapperCall->getOperand(8)->getType()->isPointerTy());
5842 }
5843 
5844 TEST_F(OpenMPIRBuilderTest, TargetEnterData) {
5845   OpenMPIRBuilder OMPBuilder(*M);
5846   OMPBuilder.initialize();
5847   F->setName("func");
5848   IRBuilder<> Builder(BB);
5849   OpenMPIRBuilder::LocationDescription Loc({Builder.saveIP(), DL});
5850 
5851   int64_t DeviceID = 2;
5852 
5853   AllocaInst *Val1 =
5854       Builder.CreateAlloca(Builder.getInt32Ty(), Builder.getInt64(1));
5855   ASSERT_NE(Val1, nullptr);
5856 
5857   IRBuilder<>::InsertPoint AllocaIP(&F->getEntryBlock(),
5858                                     F->getEntryBlock().getFirstInsertionPt());
5859 
5860   llvm::OpenMPIRBuilder::MapInfosTy CombinedInfo;
5861   using InsertPointTy = OpenMPIRBuilder::InsertPointTy;
5862   auto GenMapInfoCB =
5863       [&](InsertPointTy codeGenIP) -> llvm::OpenMPIRBuilder::MapInfosTy & {
5864     // Get map clause information.
5865     Builder.restoreIP(codeGenIP);
5866 
5867     CombinedInfo.BasePointers.emplace_back(Val1);
5868     CombinedInfo.Pointers.emplace_back(Val1);
5869     CombinedInfo.DevicePointers.emplace_back(
5870         llvm::OpenMPIRBuilder::DeviceInfoTy::None);
5871     CombinedInfo.Sizes.emplace_back(Builder.getInt64(4));
5872     CombinedInfo.Types.emplace_back(llvm::omp::OpenMPOffloadMappingFlags(1));
5873     uint32_t temp;
5874     CombinedInfo.Names.emplace_back(
5875         OMPBuilder.getOrCreateSrcLocStr("unknown", temp));
5876     return CombinedInfo;
5877   };
5878 
5879   llvm::OpenMPIRBuilder::TargetDataInfo Info(
5880       /*RequiresDevicePointerInfo=*/false,
5881       /*SeparateBeginEndCalls=*/true);
5882 
5883   OMPBuilder.Config.setIsGPU(true);
5884 
5885   llvm::omp::RuntimeFunction RTLFunc = OMPRTL___tgt_target_data_begin_mapper;
5886   OpenMPIRBuilder::InsertPointOrErrorTy AfterIP = OMPBuilder.createTargetData(
5887       Loc, AllocaIP, Builder.saveIP(), Builder.getInt64(DeviceID),
5888       /* IfCond= */ nullptr, Info, GenMapInfoCB, &RTLFunc);
5889   assert(AfterIP && "unexpected error");
5890   Builder.restoreIP(*AfterIP);
5891 
5892   CallInst *TargetDataCall = dyn_cast<CallInst>(&BB->back());
5893   EXPECT_NE(TargetDataCall, nullptr);
5894   EXPECT_EQ(TargetDataCall->arg_size(), 9U);
5895   EXPECT_EQ(TargetDataCall->getCalledFunction()->getName(),
5896             "__tgt_target_data_begin_mapper");
5897   EXPECT_TRUE(TargetDataCall->getOperand(1)->getType()->isIntegerTy(64));
5898   EXPECT_TRUE(TargetDataCall->getOperand(2)->getType()->isIntegerTy(32));
5899   EXPECT_TRUE(TargetDataCall->getOperand(8)->getType()->isPointerTy());
5900 
5901   Builder.CreateRetVoid();
5902   EXPECT_FALSE(verifyModule(*M, &errs()));
5903 }
5904 
5905 TEST_F(OpenMPIRBuilderTest, TargetExitData) {
5906   OpenMPIRBuilder OMPBuilder(*M);
5907   OMPBuilder.initialize();
5908   F->setName("func");
5909   IRBuilder<> Builder(BB);
5910   OpenMPIRBuilder::LocationDescription Loc({Builder.saveIP(), DL});
5911 
5912   int64_t DeviceID = 2;
5913 
5914   AllocaInst *Val1 =
5915       Builder.CreateAlloca(Builder.getInt32Ty(), Builder.getInt64(1));
5916   ASSERT_NE(Val1, nullptr);
5917 
5918   IRBuilder<>::InsertPoint AllocaIP(&F->getEntryBlock(),
5919                                     F->getEntryBlock().getFirstInsertionPt());
5920 
5921   llvm::OpenMPIRBuilder::MapInfosTy CombinedInfo;
5922   using InsertPointTy = OpenMPIRBuilder::InsertPointTy;
5923   auto GenMapInfoCB =
5924       [&](InsertPointTy codeGenIP) -> llvm::OpenMPIRBuilder::MapInfosTy & {
5925     // Get map clause information.
5926     Builder.restoreIP(codeGenIP);
5927 
5928     CombinedInfo.BasePointers.emplace_back(Val1);
5929     CombinedInfo.Pointers.emplace_back(Val1);
5930     CombinedInfo.DevicePointers.emplace_back(
5931         llvm::OpenMPIRBuilder::DeviceInfoTy::None);
5932     CombinedInfo.Sizes.emplace_back(Builder.getInt64(4));
5933     CombinedInfo.Types.emplace_back(llvm::omp::OpenMPOffloadMappingFlags(2));
5934     uint32_t temp;
5935     CombinedInfo.Names.emplace_back(
5936         OMPBuilder.getOrCreateSrcLocStr("unknown", temp));
5937     return CombinedInfo;
5938   };
5939 
5940   llvm::OpenMPIRBuilder::TargetDataInfo Info(
5941       /*RequiresDevicePointerInfo=*/false,
5942       /*SeparateBeginEndCalls=*/true);
5943 
5944   OMPBuilder.Config.setIsGPU(true);
5945 
5946   llvm::omp::RuntimeFunction RTLFunc = OMPRTL___tgt_target_data_end_mapper;
5947   OpenMPIRBuilder::InsertPointOrErrorTy AfterIP = OMPBuilder.createTargetData(
5948       Loc, AllocaIP, Builder.saveIP(), Builder.getInt64(DeviceID),
5949       /* IfCond= */ nullptr, Info, GenMapInfoCB, &RTLFunc);
5950   assert(AfterIP && "unexpected error");
5951   Builder.restoreIP(*AfterIP);
5952 
5953   CallInst *TargetDataCall = dyn_cast<CallInst>(&BB->back());
5954   EXPECT_NE(TargetDataCall, nullptr);
5955   EXPECT_EQ(TargetDataCall->arg_size(), 9U);
5956   EXPECT_EQ(TargetDataCall->getCalledFunction()->getName(),
5957             "__tgt_target_data_end_mapper");
5958   EXPECT_TRUE(TargetDataCall->getOperand(1)->getType()->isIntegerTy(64));
5959   EXPECT_TRUE(TargetDataCall->getOperand(2)->getType()->isIntegerTy(32));
5960   EXPECT_TRUE(TargetDataCall->getOperand(8)->getType()->isPointerTy());
5961 
5962   Builder.CreateRetVoid();
5963   EXPECT_FALSE(verifyModule(*M, &errs()));
5964 }
5965 
5966 TEST_F(OpenMPIRBuilderTest, TargetDataRegion) {
5967   OpenMPIRBuilder OMPBuilder(*M);
5968   OMPBuilder.initialize();
5969   F->setName("func");
5970   IRBuilder<> Builder(BB);
5971   OpenMPIRBuilder::LocationDescription Loc({Builder.saveIP(), DL});
5972 
5973   int64_t DeviceID = 2;
5974 
5975   AllocaInst *Val1 =
5976       Builder.CreateAlloca(Builder.getInt32Ty(), Builder.getInt64(1));
5977   ASSERT_NE(Val1, nullptr);
5978 
5979   AllocaInst *Val2 = Builder.CreateAlloca(Builder.getPtrTy());
5980   ASSERT_NE(Val2, nullptr);
5981 
5982   AllocaInst *Val3 = Builder.CreateAlloca(Builder.getPtrTy());
5983   ASSERT_NE(Val3, nullptr);
5984 
5985   IRBuilder<>::InsertPoint AllocaIP(&F->getEntryBlock(),
5986                                     F->getEntryBlock().getFirstInsertionPt());
5987 
5988   using DeviceInfoTy = llvm::OpenMPIRBuilder::DeviceInfoTy;
5989   llvm::OpenMPIRBuilder::MapInfosTy CombinedInfo;
5990   using InsertPointTy = OpenMPIRBuilder::InsertPointTy;
5991   auto GenMapInfoCB =
5992       [&](InsertPointTy codeGenIP) -> llvm::OpenMPIRBuilder::MapInfosTy & {
5993     // Get map clause information.
5994     Builder.restoreIP(codeGenIP);
5995     uint32_t temp;
5996 
5997     CombinedInfo.BasePointers.emplace_back(Val1);
5998     CombinedInfo.Pointers.emplace_back(Val1);
5999     CombinedInfo.DevicePointers.emplace_back(DeviceInfoTy::None);
6000     CombinedInfo.Sizes.emplace_back(Builder.getInt64(4));
6001     CombinedInfo.Types.emplace_back(llvm::omp::OpenMPOffloadMappingFlags(3));
6002     CombinedInfo.Names.emplace_back(
6003         OMPBuilder.getOrCreateSrcLocStr("unknown", temp));
6004 
6005     CombinedInfo.BasePointers.emplace_back(Val2);
6006     CombinedInfo.Pointers.emplace_back(Val2);
6007     CombinedInfo.DevicePointers.emplace_back(DeviceInfoTy::Pointer);
6008     CombinedInfo.Sizes.emplace_back(Builder.getInt64(8));
6009     CombinedInfo.Types.emplace_back(llvm::omp::OpenMPOffloadMappingFlags(67));
6010     CombinedInfo.Names.emplace_back(
6011         OMPBuilder.getOrCreateSrcLocStr("unknown", temp));
6012 
6013     CombinedInfo.BasePointers.emplace_back(Val3);
6014     CombinedInfo.Pointers.emplace_back(Val3);
6015     CombinedInfo.DevicePointers.emplace_back(DeviceInfoTy::Address);
6016     CombinedInfo.Sizes.emplace_back(Builder.getInt64(8));
6017     CombinedInfo.Types.emplace_back(llvm::omp::OpenMPOffloadMappingFlags(67));
6018     CombinedInfo.Names.emplace_back(
6019         OMPBuilder.getOrCreateSrcLocStr("unknown", temp));
6020     return CombinedInfo;
6021   };
6022 
6023   llvm::OpenMPIRBuilder::TargetDataInfo Info(
6024       /*RequiresDevicePointerInfo=*/true,
6025       /*SeparateBeginEndCalls=*/true);
6026 
6027   OMPBuilder.Config.setIsGPU(true);
6028 
6029   using BodyGenTy = llvm::OpenMPIRBuilder::BodyGenTy;
6030   auto BodyCB = [&](InsertPointTy CodeGenIP, BodyGenTy BodyGenType) {
6031     if (BodyGenType == BodyGenTy::Priv) {
6032       EXPECT_EQ(Info.DevicePtrInfoMap.size(), 2u);
6033       Builder.restoreIP(CodeGenIP);
6034       CallInst *TargetDataCall =
6035           dyn_cast<CallInst>(BB->back().getPrevNode()->getPrevNode());
6036       EXPECT_NE(TargetDataCall, nullptr);
6037       EXPECT_EQ(TargetDataCall->arg_size(), 9U);
6038       EXPECT_EQ(TargetDataCall->getCalledFunction()->getName(),
6039                 "__tgt_target_data_begin_mapper");
6040       EXPECT_TRUE(TargetDataCall->getOperand(1)->getType()->isIntegerTy(64));
6041       EXPECT_TRUE(TargetDataCall->getOperand(2)->getType()->isIntegerTy(32));
6042       EXPECT_TRUE(TargetDataCall->getOperand(8)->getType()->isPointerTy());
6043 
6044       LoadInst *LI = dyn_cast<LoadInst>(BB->back().getPrevNode());
6045       EXPECT_NE(LI, nullptr);
6046       StoreInst *SI = dyn_cast<StoreInst>(&BB->back());
6047       EXPECT_NE(SI, nullptr);
6048       EXPECT_EQ(SI->getValueOperand(), LI);
6049       EXPECT_EQ(SI->getPointerOperand(), Info.DevicePtrInfoMap[Val2].second);
6050       EXPECT_TRUE(isa<AllocaInst>(Info.DevicePtrInfoMap[Val2].second));
6051       EXPECT_TRUE(isa<GetElementPtrInst>(Info.DevicePtrInfoMap[Val3].second));
6052       Builder.CreateStore(Builder.getInt32(99), Val1);
6053     }
6054     return Builder.saveIP();
6055   };
6056 
6057   OpenMPIRBuilder::InsertPointOrErrorTy TargetDataIP1 =
6058       OMPBuilder.createTargetData(
6059           Loc, AllocaIP, Builder.saveIP(), Builder.getInt64(DeviceID),
6060           /* IfCond= */ nullptr, Info, GenMapInfoCB, nullptr, BodyCB);
6061   assert(TargetDataIP1 && "unexpected error");
6062   Builder.restoreIP(*TargetDataIP1);
6063 
6064   CallInst *TargetDataCall = dyn_cast<CallInst>(&BB->back());
6065   EXPECT_NE(TargetDataCall, nullptr);
6066   EXPECT_EQ(TargetDataCall->arg_size(), 9U);
6067   EXPECT_EQ(TargetDataCall->getCalledFunction()->getName(),
6068             "__tgt_target_data_end_mapper");
6069   EXPECT_TRUE(TargetDataCall->getOperand(1)->getType()->isIntegerTy(64));
6070   EXPECT_TRUE(TargetDataCall->getOperand(2)->getType()->isIntegerTy(32));
6071   EXPECT_TRUE(TargetDataCall->getOperand(8)->getType()->isPointerTy());
6072 
6073   // Check that BodyGenCB is still made when IsTargetDevice is set to true.
6074   OMPBuilder.Config.setIsTargetDevice(true);
6075   bool CheckDevicePassBodyGen = false;
6076   auto BodyTargetCB = [&](InsertPointTy CodeGenIP, BodyGenTy BodyGenType) {
6077     CheckDevicePassBodyGen = true;
6078     Builder.restoreIP(CodeGenIP);
6079     CallInst *TargetDataCall =
6080         dyn_cast<CallInst>(BB->back().getPrevNode()->getPrevNode());
6081     // Make sure no begin_mapper call is present for device pass.
6082     EXPECT_EQ(TargetDataCall, nullptr);
6083     return Builder.saveIP();
6084   };
6085   OpenMPIRBuilder::InsertPointOrErrorTy TargetDataIP2 =
6086       OMPBuilder.createTargetData(
6087           Loc, AllocaIP, Builder.saveIP(), Builder.getInt64(DeviceID),
6088           /* IfCond= */ nullptr, Info, GenMapInfoCB, nullptr, BodyTargetCB);
6089   assert(TargetDataIP2 && "unexpected error");
6090   Builder.restoreIP(*TargetDataIP2);
6091   EXPECT_TRUE(CheckDevicePassBodyGen);
6092 
6093   Builder.CreateRetVoid();
6094   EXPECT_FALSE(verifyModule(*M, &errs()));
6095 }
6096 
6097 namespace {
6098 // Some basic handling of argument mapping for the moment
6099 void CreateDefaultMapInfos(llvm::OpenMPIRBuilder &OmpBuilder,
6100                            llvm::SmallVectorImpl<llvm::Value *> &Args,
6101                            llvm::OpenMPIRBuilder::MapInfosTy &CombinedInfo) {
6102   for (auto Arg : Args) {
6103     CombinedInfo.BasePointers.emplace_back(Arg);
6104     CombinedInfo.Pointers.emplace_back(Arg);
6105     uint32_t SrcLocStrSize;
6106     CombinedInfo.Names.emplace_back(OmpBuilder.getOrCreateSrcLocStr(
6107         "Unknown loc - stub implementation", SrcLocStrSize));
6108     CombinedInfo.Types.emplace_back(llvm::omp::OpenMPOffloadMappingFlags(
6109         llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO |
6110         llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM |
6111         llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TARGET_PARAM));
6112     CombinedInfo.Sizes.emplace_back(OmpBuilder.Builder.getInt64(
6113         OmpBuilder.M.getDataLayout().getTypeAllocSize(Arg->getType())));
6114   }
6115 }
6116 } // namespace
6117 
6118 TEST_F(OpenMPIRBuilderTest, TargetRegion) {
6119   using InsertPointTy = OpenMPIRBuilder::InsertPointTy;
6120   OpenMPIRBuilder OMPBuilder(*M);
6121   OMPBuilder.initialize();
6122   OpenMPIRBuilderConfig Config(false, false, false, false, false, false, false);
6123   OMPBuilder.setConfig(Config);
6124   F->setName("func");
6125   IRBuilder<> Builder(BB);
6126   auto Int32Ty = Builder.getInt32Ty();
6127 
6128   AllocaInst *APtr = Builder.CreateAlloca(Int32Ty, nullptr, "a_ptr");
6129   AllocaInst *BPtr = Builder.CreateAlloca(Int32Ty, nullptr, "b_ptr");
6130   AllocaInst *CPtr = Builder.CreateAlloca(Int32Ty, nullptr, "c_ptr");
6131 
6132   Builder.CreateStore(Builder.getInt32(10), APtr);
6133   Builder.CreateStore(Builder.getInt32(20), BPtr);
6134   auto BodyGenCB = [&](InsertPointTy AllocaIP,
6135                        InsertPointTy CodeGenIP) -> InsertPointTy {
6136     Builder.restoreIP(CodeGenIP);
6137     LoadInst *AVal = Builder.CreateLoad(Int32Ty, APtr);
6138     LoadInst *BVal = Builder.CreateLoad(Int32Ty, BPtr);
6139     Value *Sum = Builder.CreateAdd(AVal, BVal);
6140     Builder.CreateStore(Sum, CPtr);
6141     return Builder.saveIP();
6142   };
6143 
6144   llvm::SmallVector<llvm::Value *> Inputs;
6145   Inputs.push_back(APtr);
6146   Inputs.push_back(BPtr);
6147   Inputs.push_back(CPtr);
6148 
6149   auto SimpleArgAccessorCB =
6150       [&](llvm::Argument &Arg, llvm::Value *Input, llvm::Value *&RetVal,
6151           llvm::OpenMPIRBuilder::InsertPointTy AllocaIP,
6152           llvm::OpenMPIRBuilder::InsertPointTy CodeGenIP) {
6153         if (!OMPBuilder.Config.isTargetDevice()) {
6154           RetVal = cast<llvm::Value>(&Arg);
6155           return CodeGenIP;
6156         }
6157 
6158         Builder.restoreIP(AllocaIP);
6159 
6160         llvm::Value *Addr = Builder.CreateAlloca(
6161             Arg.getType()->isPointerTy()
6162                 ? Arg.getType()
6163                 : Type::getInt64Ty(Builder.getContext()),
6164             OMPBuilder.M.getDataLayout().getAllocaAddrSpace());
6165         llvm::Value *AddrAscast =
6166             Builder.CreatePointerBitCastOrAddrSpaceCast(Addr, Input->getType());
6167         Builder.CreateStore(&Arg, AddrAscast);
6168 
6169         Builder.restoreIP(CodeGenIP);
6170 
6171         RetVal = Builder.CreateLoad(Arg.getType(), AddrAscast);
6172 
6173         return Builder.saveIP();
6174       };
6175 
6176   llvm::OpenMPIRBuilder::MapInfosTy CombinedInfos;
6177   auto GenMapInfoCB = [&](llvm::OpenMPIRBuilder::InsertPointTy codeGenIP)
6178       -> llvm::OpenMPIRBuilder::MapInfosTy & {
6179     CreateDefaultMapInfos(OMPBuilder, Inputs, CombinedInfos);
6180     return CombinedInfos;
6181   };
6182 
6183   TargetRegionEntryInfo EntryInfo("func", 42, 4711, 17);
6184   OpenMPIRBuilder::LocationDescription OmpLoc({Builder.saveIP(), DL});
6185   OpenMPIRBuilder::InsertPointOrErrorTy AfterIP = OMPBuilder.createTarget(
6186       OmpLoc, /*IsOffloadEntry=*/true, Builder.saveIP(), Builder.saveIP(),
6187       EntryInfo, -1, 0, Inputs, GenMapInfoCB, BodyGenCB, SimpleArgAccessorCB);
6188   assert(AfterIP && "unexpected error");
6189   Builder.restoreIP(*AfterIP);
6190   OMPBuilder.finalize();
6191   Builder.CreateRetVoid();
6192 
6193   // Check the kernel launch sequence
6194   auto Iter = F->getEntryBlock().rbegin();
6195   EXPECT_TRUE(isa<BranchInst>(&*(Iter)));
6196   BranchInst *Branch = dyn_cast<BranchInst>(&*(Iter));
6197   EXPECT_TRUE(isa<CmpInst>(&*(++Iter)));
6198   EXPECT_TRUE(isa<CallInst>(&*(++Iter)));
6199   CallInst *Call = dyn_cast<CallInst>(&*(Iter));
6200 
6201   // Check that the kernel launch function is called
6202   Function *KernelLaunchFunc = Call->getCalledFunction();
6203   EXPECT_NE(KernelLaunchFunc, nullptr);
6204   StringRef FunctionName = KernelLaunchFunc->getName();
6205   EXPECT_TRUE(FunctionName.starts_with("__tgt_target_kernel"));
6206 
6207   // Check the fallback call
6208   BasicBlock *FallbackBlock = Branch->getSuccessor(0);
6209   Iter = FallbackBlock->rbegin();
6210   CallInst *FCall = dyn_cast<CallInst>(&*(++Iter));
6211   // 'F' has a dummy DISubprogram which causes OutlinedFunc to also
6212   // have a DISubprogram. In this case, the call to OutlinedFunc needs
6213   // to have a debug loc, otherwise verifier will complain.
6214   FCall->setDebugLoc(DL);
6215   EXPECT_NE(FCall, nullptr);
6216 
6217   // Check that the correct aguments are passed in
6218   for (auto ArgInput : zip(FCall->args(), Inputs)) {
6219     EXPECT_EQ(std::get<0>(ArgInput), std::get<1>(ArgInput));
6220   }
6221 
6222   // Check that the outlined function exists with the expected prefix
6223   Function *OutlinedFunc = FCall->getCalledFunction();
6224   EXPECT_NE(OutlinedFunc, nullptr);
6225   StringRef FunctionName2 = OutlinedFunc->getName();
6226   EXPECT_TRUE(FunctionName2.starts_with("__omp_offloading"));
6227 
6228   EXPECT_FALSE(verifyModule(*M, &errs()));
6229 }
6230 
6231 TEST_F(OpenMPIRBuilderTest, TargetRegionDevice) {
6232   OpenMPIRBuilder OMPBuilder(*M);
6233   OMPBuilder.setConfig(
6234       OpenMPIRBuilderConfig(true, false, false, false, false, false, false));
6235   OMPBuilder.initialize();
6236 
6237   F->setName("func");
6238   IRBuilder<> Builder(BB);
6239   OpenMPIRBuilder::LocationDescription Loc({Builder.saveIP(), DL});
6240 
6241   LoadInst *Value = nullptr;
6242   StoreInst *TargetStore = nullptr;
6243   llvm::SmallVector<llvm::Value *, 2> CapturedArgs = {
6244       Constant::getNullValue(PointerType::get(Ctx, 0)),
6245       Constant::getNullValue(PointerType::get(Ctx, 0))};
6246 
6247   auto SimpleArgAccessorCB =
6248       [&](llvm::Argument &Arg, llvm::Value *Input, llvm::Value *&RetVal,
6249           llvm::OpenMPIRBuilder::InsertPointTy AllocaIP,
6250           llvm::OpenMPIRBuilder::InsertPointTy CodeGenIP) {
6251         if (!OMPBuilder.Config.isTargetDevice()) {
6252           RetVal = cast<llvm::Value>(&Arg);
6253           return CodeGenIP;
6254         }
6255 
6256         Builder.restoreIP(AllocaIP);
6257 
6258         llvm::Value *Addr = Builder.CreateAlloca(
6259             Arg.getType()->isPointerTy()
6260                 ? Arg.getType()
6261                 : Type::getInt64Ty(Builder.getContext()),
6262             OMPBuilder.M.getDataLayout().getAllocaAddrSpace());
6263         llvm::Value *AddrAscast =
6264             Builder.CreatePointerBitCastOrAddrSpaceCast(Addr, Input->getType());
6265         Builder.CreateStore(&Arg, AddrAscast);
6266 
6267         Builder.restoreIP(CodeGenIP);
6268 
6269         RetVal = Builder.CreateLoad(Arg.getType(), AddrAscast);
6270 
6271         return Builder.saveIP();
6272       };
6273 
6274   llvm::OpenMPIRBuilder::MapInfosTy CombinedInfos;
6275   auto GenMapInfoCB = [&](llvm::OpenMPIRBuilder::InsertPointTy codeGenIP)
6276       -> llvm::OpenMPIRBuilder::MapInfosTy & {
6277     CreateDefaultMapInfos(OMPBuilder, CapturedArgs, CombinedInfos);
6278     return CombinedInfos;
6279   };
6280 
6281   auto BodyGenCB = [&](OpenMPIRBuilder::InsertPointTy AllocaIP,
6282                        OpenMPIRBuilder::InsertPointTy CodeGenIP)
6283       -> OpenMPIRBuilder::InsertPointTy {
6284     Builder.restoreIP(CodeGenIP);
6285     Value = Builder.CreateLoad(Type::getInt32Ty(Ctx), CapturedArgs[0]);
6286     TargetStore = Builder.CreateStore(Value, CapturedArgs[1]);
6287     return Builder.saveIP();
6288   };
6289 
6290   IRBuilder<>::InsertPoint EntryIP(&F->getEntryBlock(),
6291                                    F->getEntryBlock().getFirstInsertionPt());
6292   TargetRegionEntryInfo EntryInfo("parent", /*DeviceID=*/1, /*FileID=*/2,
6293                                   /*Line=*/3, /*Count=*/0);
6294 
6295   OpenMPIRBuilder::InsertPointOrErrorTy AfterIP =
6296       OMPBuilder.createTarget(Loc, /*IsOffloadEntry=*/true, EntryIP, EntryIP,
6297                               EntryInfo, /*NumTeams=*/-1,
6298                               /*NumThreads=*/0, CapturedArgs, GenMapInfoCB,
6299                               BodyGenCB, SimpleArgAccessorCB);
6300   assert(AfterIP && "unexpected error");
6301   Builder.restoreIP(*AfterIP);
6302 
6303   Builder.CreateRetVoid();
6304   OMPBuilder.finalize();
6305 
6306   // Check outlined function
6307   EXPECT_FALSE(verifyModule(*M, &errs()));
6308   EXPECT_NE(TargetStore, nullptr);
6309   Function *OutlinedFn = TargetStore->getFunction();
6310   EXPECT_NE(F, OutlinedFn);
6311 
6312   EXPECT_TRUE(OutlinedFn->hasWeakODRLinkage());
6313   // Account for the "implicit" first argument.
6314   EXPECT_EQ(OutlinedFn->getName(), "__omp_offloading_1_2_parent_l3");
6315   EXPECT_EQ(OutlinedFn->arg_size(), 3U);
6316   EXPECT_TRUE(OutlinedFn->getArg(1)->getType()->isPointerTy());
6317   EXPECT_TRUE(OutlinedFn->getArg(2)->getType()->isPointerTy());
6318 
6319   // Check entry block
6320   auto &EntryBlock = OutlinedFn->getEntryBlock();
6321   Instruction *Alloca1 = EntryBlock.getFirstNonPHI();
6322   EXPECT_NE(Alloca1, nullptr);
6323 
6324   EXPECT_TRUE(isa<AllocaInst>(Alloca1));
6325   auto *Store1 = Alloca1->getNextNode();
6326   EXPECT_TRUE(isa<StoreInst>(Store1));
6327   auto *Alloca2 = Store1->getNextNode();
6328   EXPECT_TRUE(isa<AllocaInst>(Alloca2));
6329   auto *Store2 = Alloca2->getNextNode();
6330   EXPECT_TRUE(isa<StoreInst>(Store2));
6331 
6332   auto *InitCall = dyn_cast<CallInst>(Store2->getNextNode());
6333   EXPECT_NE(InitCall, nullptr);
6334   EXPECT_EQ(InitCall->getCalledFunction()->getName(), "__kmpc_target_init");
6335   EXPECT_EQ(InitCall->arg_size(), 2U);
6336   EXPECT_TRUE(isa<GlobalVariable>(InitCall->getArgOperand(0)));
6337   auto *KernelEnvGV = cast<GlobalVariable>(InitCall->getArgOperand(0));
6338   EXPECT_TRUE(isa<ConstantStruct>(KernelEnvGV->getInitializer()));
6339   auto *KernelEnvC = cast<ConstantStruct>(KernelEnvGV->getInitializer());
6340   EXPECT_TRUE(isa<ConstantStruct>(KernelEnvC->getAggregateElement(0U)));
6341   auto ConfigC = cast<ConstantStruct>(KernelEnvC->getAggregateElement(0U));
6342   EXPECT_EQ(ConfigC->getAggregateElement(0U),
6343             ConstantInt::get(Type::getInt8Ty(Ctx), true));
6344   EXPECT_EQ(ConfigC->getAggregateElement(1U),
6345             ConstantInt::get(Type::getInt8Ty(Ctx), true));
6346   EXPECT_EQ(ConfigC->getAggregateElement(2U),
6347             ConstantInt::get(Type::getInt8Ty(Ctx), OMP_TGT_EXEC_MODE_GENERIC));
6348 
6349   auto *EntryBlockBranch = EntryBlock.getTerminator();
6350   EXPECT_NE(EntryBlockBranch, nullptr);
6351   EXPECT_EQ(EntryBlockBranch->getNumSuccessors(), 2U);
6352 
6353   // Check user code block
6354   auto *UserCodeBlock = EntryBlockBranch->getSuccessor(0);
6355   EXPECT_EQ(UserCodeBlock->getName(), "user_code.entry");
6356   auto *Load1 = UserCodeBlock->getFirstNonPHI();
6357   EXPECT_TRUE(isa<LoadInst>(Load1));
6358   auto *Load2 = Load1->getNextNode();
6359   EXPECT_TRUE(isa<LoadInst>(Load2));
6360 
6361   auto *Value1 = Load2->getNextNode();
6362   EXPECT_EQ(Value1, Value);
6363   EXPECT_EQ(Value1->getNextNode(), TargetStore);
6364   auto *Deinit = TargetStore->getNextNode();
6365   EXPECT_NE(Deinit, nullptr);
6366 
6367   auto *DeinitCall = dyn_cast<CallInst>(Deinit);
6368   EXPECT_NE(DeinitCall, nullptr);
6369   EXPECT_EQ(DeinitCall->getCalledFunction()->getName(), "__kmpc_target_deinit");
6370   EXPECT_EQ(DeinitCall->arg_size(), 0U);
6371 
6372   EXPECT_TRUE(isa<ReturnInst>(DeinitCall->getNextNode()));
6373 
6374   // Check exit block
6375   auto *ExitBlock = EntryBlockBranch->getSuccessor(1);
6376   EXPECT_EQ(ExitBlock->getName(), "worker.exit");
6377   EXPECT_TRUE(isa<ReturnInst>(ExitBlock->getFirstNonPHI()));
6378 }
6379 
6380 TEST_F(OpenMPIRBuilderTest, ConstantAllocaRaise) {
6381   OpenMPIRBuilder OMPBuilder(*M);
6382   OMPBuilder.setConfig(
6383       OpenMPIRBuilderConfig(true, false, false, false, false, false, false));
6384   OMPBuilder.initialize();
6385 
6386   F->setName("func");
6387   IRBuilder<> Builder(BB);
6388   OpenMPIRBuilder::LocationDescription Loc({Builder.saveIP(), DL});
6389 
6390   LoadInst *Value = nullptr;
6391   StoreInst *TargetStore = nullptr;
6392   llvm::SmallVector<llvm::Value *, 1> CapturedArgs = {
6393       Constant::getNullValue(PointerType::get(Ctx, 0))};
6394 
6395   auto SimpleArgAccessorCB =
6396       [&](llvm::Argument &Arg, llvm::Value *Input, llvm::Value *&RetVal,
6397           llvm::OpenMPIRBuilder::InsertPointTy AllocaIP,
6398           llvm::OpenMPIRBuilder::InsertPointTy CodeGenIP) {
6399         if (!OMPBuilder.Config.isTargetDevice()) {
6400           RetVal = cast<llvm::Value>(&Arg);
6401           return CodeGenIP;
6402         }
6403 
6404         Builder.restoreIP(AllocaIP);
6405 
6406         llvm::Value *Addr = Builder.CreateAlloca(
6407             Arg.getType()->isPointerTy()
6408                 ? Arg.getType()
6409                 : Type::getInt64Ty(Builder.getContext()),
6410             OMPBuilder.M.getDataLayout().getAllocaAddrSpace());
6411         llvm::Value *AddrAscast =
6412             Builder.CreatePointerBitCastOrAddrSpaceCast(Addr, Input->getType());
6413         Builder.CreateStore(&Arg, AddrAscast);
6414 
6415         Builder.restoreIP(CodeGenIP);
6416 
6417         RetVal = Builder.CreateLoad(Arg.getType(), AddrAscast);
6418 
6419         return Builder.saveIP();
6420       };
6421 
6422   llvm::OpenMPIRBuilder::MapInfosTy CombinedInfos;
6423   auto GenMapInfoCB = [&](llvm::OpenMPIRBuilder::InsertPointTy codeGenIP)
6424       -> llvm::OpenMPIRBuilder::MapInfosTy & {
6425     CreateDefaultMapInfos(OMPBuilder, CapturedArgs, CombinedInfos);
6426     return CombinedInfos;
6427   };
6428 
6429   llvm::Value *RaiseAlloca = nullptr;
6430 
6431   auto BodyGenCB = [&](OpenMPIRBuilder::InsertPointTy AllocaIP,
6432                        OpenMPIRBuilder::InsertPointTy CodeGenIP)
6433       -> OpenMPIRBuilder::InsertPointTy {
6434     Builder.restoreIP(CodeGenIP);
6435     RaiseAlloca = Builder.CreateAlloca(Builder.getInt32Ty());
6436     Value = Builder.CreateLoad(Type::getInt32Ty(Ctx), CapturedArgs[0]);
6437     TargetStore = Builder.CreateStore(Value, RaiseAlloca);
6438     return Builder.saveIP();
6439   };
6440 
6441   IRBuilder<>::InsertPoint EntryIP(&F->getEntryBlock(),
6442                                    F->getEntryBlock().getFirstInsertionPt());
6443   TargetRegionEntryInfo EntryInfo("parent", /*DeviceID=*/1, /*FileID=*/2,
6444                                   /*Line=*/3, /*Count=*/0);
6445 
6446   OpenMPIRBuilder::InsertPointOrErrorTy AfterIP =
6447       OMPBuilder.createTarget(Loc, /*IsOffloadEntry=*/true, EntryIP, EntryIP,
6448                               EntryInfo, /*NumTeams=*/-1,
6449                               /*NumThreads=*/0, CapturedArgs, GenMapInfoCB,
6450                               BodyGenCB, SimpleArgAccessorCB);
6451   assert(AfterIP && "unexpected error");
6452   Builder.restoreIP(*AfterIP);
6453 
6454   Builder.CreateRetVoid();
6455   OMPBuilder.finalize();
6456 
6457   // Check outlined function
6458   EXPECT_FALSE(verifyModule(*M, &errs()));
6459   EXPECT_NE(TargetStore, nullptr);
6460   Function *OutlinedFn = TargetStore->getFunction();
6461   EXPECT_NE(F, OutlinedFn);
6462 
6463   EXPECT_TRUE(OutlinedFn->hasWeakODRLinkage());
6464   // Account for the "implicit" first argument.
6465   EXPECT_EQ(OutlinedFn->getName(), "__omp_offloading_1_2_parent_l3");
6466   EXPECT_EQ(OutlinedFn->arg_size(), 2U);
6467   EXPECT_TRUE(OutlinedFn->getArg(1)->getType()->isPointerTy());
6468 
6469   // Check entry block, to see if we have raised our alloca
6470   // from the body to the entry block.
6471   auto &EntryBlock = OutlinedFn->getEntryBlock();
6472 
6473   // Check that we have moved our alloca created in the
6474   // BodyGenCB function, to the top of the function.
6475   Instruction *Alloca1 = EntryBlock.getFirstNonPHI();
6476   EXPECT_NE(Alloca1, nullptr);
6477   EXPECT_TRUE(isa<AllocaInst>(Alloca1));
6478   EXPECT_EQ(Alloca1, RaiseAlloca);
6479 
6480   // Verify we have not altered the rest of the function
6481   // inappropriately with our alloca movement.
6482   auto *Alloca2 = Alloca1->getNextNode();
6483   EXPECT_TRUE(isa<AllocaInst>(Alloca2));
6484   auto *Store2 = Alloca2->getNextNode();
6485   EXPECT_TRUE(isa<StoreInst>(Store2));
6486 
6487   auto *InitCall = dyn_cast<CallInst>(Store2->getNextNode());
6488   EXPECT_NE(InitCall, nullptr);
6489   EXPECT_EQ(InitCall->getCalledFunction()->getName(), "__kmpc_target_init");
6490   EXPECT_EQ(InitCall->arg_size(), 2U);
6491   EXPECT_TRUE(isa<GlobalVariable>(InitCall->getArgOperand(0)));
6492   auto *KernelEnvGV = cast<GlobalVariable>(InitCall->getArgOperand(0));
6493   EXPECT_TRUE(isa<ConstantStruct>(KernelEnvGV->getInitializer()));
6494   auto *KernelEnvC = cast<ConstantStruct>(KernelEnvGV->getInitializer());
6495   EXPECT_TRUE(isa<ConstantStruct>(KernelEnvC->getAggregateElement(0U)));
6496   auto *ConfigC = cast<ConstantStruct>(KernelEnvC->getAggregateElement(0U));
6497   EXPECT_EQ(ConfigC->getAggregateElement(0U),
6498             ConstantInt::get(Type::getInt8Ty(Ctx), true));
6499   EXPECT_EQ(ConfigC->getAggregateElement(1U),
6500             ConstantInt::get(Type::getInt8Ty(Ctx), true));
6501   EXPECT_EQ(ConfigC->getAggregateElement(2U),
6502             ConstantInt::get(Type::getInt8Ty(Ctx), OMP_TGT_EXEC_MODE_GENERIC));
6503 
6504   auto *EntryBlockBranch = EntryBlock.getTerminator();
6505   EXPECT_NE(EntryBlockBranch, nullptr);
6506   EXPECT_EQ(EntryBlockBranch->getNumSuccessors(), 2U);
6507 
6508   // Check user code block
6509   auto *UserCodeBlock = EntryBlockBranch->getSuccessor(0);
6510   EXPECT_EQ(UserCodeBlock->getName(), "user_code.entry");
6511   auto *Load1 = UserCodeBlock->getFirstNonPHI();
6512   EXPECT_TRUE(isa<LoadInst>(Load1));
6513   auto *Load2 = Load1->getNextNode();
6514   EXPECT_TRUE(isa<LoadInst>(Load2));
6515   EXPECT_EQ(Load2, Value);
6516   EXPECT_EQ(Load2->getNextNode(), TargetStore);
6517   auto *Deinit = TargetStore->getNextNode();
6518   EXPECT_NE(Deinit, nullptr);
6519 
6520   auto *DeinitCall = dyn_cast<CallInst>(Deinit);
6521   EXPECT_NE(DeinitCall, nullptr);
6522   EXPECT_EQ(DeinitCall->getCalledFunction()->getName(), "__kmpc_target_deinit");
6523   EXPECT_EQ(DeinitCall->arg_size(), 0U);
6524 
6525   EXPECT_TRUE(isa<ReturnInst>(DeinitCall->getNextNode()));
6526 
6527   // Check exit block
6528   auto *ExitBlock = EntryBlockBranch->getSuccessor(1);
6529   EXPECT_EQ(ExitBlock->getName(), "worker.exit");
6530   EXPECT_TRUE(isa<ReturnInst>(ExitBlock->getFirstNonPHI()));
6531 }
6532 
6533 TEST_F(OpenMPIRBuilderTest, CreateTask) {
6534   using InsertPointTy = OpenMPIRBuilder::InsertPointTy;
6535   OpenMPIRBuilder OMPBuilder(*M);
6536   OMPBuilder.Config.IsTargetDevice = false;
6537   OMPBuilder.initialize();
6538   F->setName("func");
6539   IRBuilder<> Builder(BB);
6540 
6541   AllocaInst *ValPtr32 = Builder.CreateAlloca(Builder.getInt32Ty());
6542   AllocaInst *ValPtr128 = Builder.CreateAlloca(Builder.getInt128Ty());
6543   Value *Val128 =
6544       Builder.CreateLoad(Builder.getInt128Ty(), ValPtr128, "bodygen.load");
6545 
6546   auto BodyGenCB = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP) {
6547     Builder.restoreIP(AllocaIP);
6548     AllocaInst *Local128 = Builder.CreateAlloca(Builder.getInt128Ty(), nullptr,
6549                                                 "bodygen.alloca128");
6550 
6551     Builder.restoreIP(CodeGenIP);
6552     // Loading and storing captured pointer and values
6553     Builder.CreateStore(Val128, Local128);
6554     Value *Val32 = Builder.CreateLoad(ValPtr32->getAllocatedType(), ValPtr32,
6555                                       "bodygen.load32");
6556 
6557     LoadInst *PrivLoad128 = Builder.CreateLoad(
6558         Local128->getAllocatedType(), Local128, "bodygen.local.load128");
6559     Value *Cmp = Builder.CreateICmpNE(
6560         Val32, Builder.CreateTrunc(PrivLoad128, Val32->getType()));
6561     Instruction *ThenTerm, *ElseTerm;
6562     SplitBlockAndInsertIfThenElse(Cmp, CodeGenIP.getBlock()->getTerminator(),
6563                                   &ThenTerm, &ElseTerm);
6564     return Error::success();
6565   };
6566 
6567   BasicBlock *AllocaBB = Builder.GetInsertBlock();
6568   BasicBlock *BodyBB = splitBB(Builder, /*CreateBranch=*/true, "alloca.split");
6569   OpenMPIRBuilder::LocationDescription Loc(
6570       InsertPointTy(BodyBB, BodyBB->getFirstInsertionPt()), DL);
6571   OpenMPIRBuilder::InsertPointOrErrorTy AfterIP = OMPBuilder.createTask(
6572       Loc, InsertPointTy(AllocaBB, AllocaBB->getFirstInsertionPt()), BodyGenCB);
6573   assert(AfterIP && "unexpected error");
6574   Builder.restoreIP(*AfterIP);
6575   OMPBuilder.finalize();
6576   Builder.CreateRetVoid();
6577 
6578   EXPECT_FALSE(verifyModule(*M, &errs()));
6579 
6580   CallInst *TaskAllocCall = dyn_cast<CallInst>(
6581       OMPBuilder.getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_omp_task_alloc)
6582           ->user_back());
6583 
6584   // Verify the Ident argument
6585   GlobalVariable *Ident = cast<GlobalVariable>(TaskAllocCall->getArgOperand(0));
6586   ASSERT_NE(Ident, nullptr);
6587   EXPECT_TRUE(Ident->hasInitializer());
6588   Constant *Initializer = Ident->getInitializer();
6589   GlobalVariable *SrcStrGlob =
6590       cast<GlobalVariable>(Initializer->getOperand(4)->stripPointerCasts());
6591   ASSERT_NE(SrcStrGlob, nullptr);
6592   ConstantDataArray *SrcSrc =
6593       dyn_cast<ConstantDataArray>(SrcStrGlob->getInitializer());
6594   ASSERT_NE(SrcSrc, nullptr);
6595 
6596   // Verify the num_threads argument.
6597   CallInst *GTID = dyn_cast<CallInst>(TaskAllocCall->getArgOperand(1));
6598   ASSERT_NE(GTID, nullptr);
6599   EXPECT_EQ(GTID->arg_size(), 1U);
6600   EXPECT_EQ(GTID->getCalledFunction()->getName(), "__kmpc_global_thread_num");
6601 
6602   // Verify the flags
6603   // TODO: Check for others flags. Currently testing only for tiedness.
6604   ConstantInt *Flags = dyn_cast<ConstantInt>(TaskAllocCall->getArgOperand(2));
6605   ASSERT_NE(Flags, nullptr);
6606   EXPECT_EQ(Flags->getSExtValue(), 1);
6607 
6608   // Verify the data size
6609   ConstantInt *DataSize =
6610       dyn_cast<ConstantInt>(TaskAllocCall->getArgOperand(3));
6611   ASSERT_NE(DataSize, nullptr);
6612   EXPECT_EQ(DataSize->getSExtValue(), 40);
6613 
6614   ConstantInt *SharedsSize =
6615       dyn_cast<ConstantInt>(TaskAllocCall->getOperand(4));
6616   EXPECT_EQ(SharedsSize->getSExtValue(),
6617             24); // 64-bit pointer + 128-bit integer
6618 
6619   // Verify Wrapper function
6620   Function *OutlinedFn =
6621       dyn_cast<Function>(TaskAllocCall->getArgOperand(5)->stripPointerCasts());
6622   ASSERT_NE(OutlinedFn, nullptr);
6623 
6624   LoadInst *SharedsLoad = dyn_cast<LoadInst>(OutlinedFn->begin()->begin());
6625   ASSERT_NE(SharedsLoad, nullptr);
6626   EXPECT_EQ(SharedsLoad->getPointerOperand(), OutlinedFn->getArg(1));
6627 
6628   EXPECT_FALSE(OutlinedFn->isDeclaration());
6629   EXPECT_EQ(OutlinedFn->getArg(0)->getType(), Builder.getInt32Ty());
6630 
6631   // Verify that the data argument is used only once, and that too in the load
6632   // instruction that is then used for accessing shared data.
6633   Value *DataPtr = OutlinedFn->getArg(1);
6634   EXPECT_EQ(DataPtr->getNumUses(), 1U);
6635   EXPECT_TRUE(isa<LoadInst>(DataPtr->uses().begin()->getUser()));
6636   Value *Data = DataPtr->uses().begin()->getUser();
6637   EXPECT_TRUE(all_of(Data->uses(), [](Use &U) {
6638     return isa<GetElementPtrInst>(U.getUser());
6639   }));
6640 
6641   // Verify the presence of `trunc` and `icmp` instructions in Outlined function
6642   EXPECT_TRUE(any_of(instructions(OutlinedFn),
6643                      [](Instruction &inst) { return isa<TruncInst>(&inst); }));
6644   EXPECT_TRUE(any_of(instructions(OutlinedFn),
6645                      [](Instruction &inst) { return isa<ICmpInst>(&inst); }));
6646 
6647   // Verify the execution of the task
6648   CallInst *TaskCall = dyn_cast<CallInst>(
6649       OMPBuilder.getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_omp_task)
6650           ->user_back());
6651   ASSERT_NE(TaskCall, nullptr);
6652   EXPECT_EQ(TaskCall->getArgOperand(0), Ident);
6653   EXPECT_EQ(TaskCall->getArgOperand(1), GTID);
6654   EXPECT_EQ(TaskCall->getArgOperand(2), TaskAllocCall);
6655 
6656   // Verify that the argument data has been copied
6657   for (User *in : TaskAllocCall->users()) {
6658     if (MemCpyInst *memCpyInst = dyn_cast<MemCpyInst>(in)) {
6659       EXPECT_EQ(memCpyInst->getDest(), TaskAllocCall);
6660     }
6661   }
6662 }
6663 
6664 TEST_F(OpenMPIRBuilderTest, CreateTaskNoArgs) {
6665   using InsertPointTy = OpenMPIRBuilder::InsertPointTy;
6666   OpenMPIRBuilder OMPBuilder(*M);
6667   OMPBuilder.Config.IsTargetDevice = false;
6668   OMPBuilder.initialize();
6669   F->setName("func");
6670   IRBuilder<> Builder(BB);
6671 
6672   auto BodyGenCB = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP) {
6673     return Error::success();
6674   };
6675 
6676   BasicBlock *AllocaBB = Builder.GetInsertBlock();
6677   BasicBlock *BodyBB = splitBB(Builder, /*CreateBranch=*/true, "alloca.split");
6678   OpenMPIRBuilder::LocationDescription Loc(
6679       InsertPointTy(BodyBB, BodyBB->getFirstInsertionPt()), DL);
6680   OpenMPIRBuilder::InsertPointOrErrorTy AfterIP = OMPBuilder.createTask(
6681       Loc, InsertPointTy(AllocaBB, AllocaBB->getFirstInsertionPt()), BodyGenCB);
6682   assert(AfterIP && "unexpected error");
6683   Builder.restoreIP(*AfterIP);
6684   OMPBuilder.finalize();
6685   Builder.CreateRetVoid();
6686 
6687   EXPECT_FALSE(verifyModule(*M, &errs()));
6688 
6689   // Check that the outlined function has only one argument.
6690   CallInst *TaskAllocCall = dyn_cast<CallInst>(
6691       OMPBuilder.getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_omp_task_alloc)
6692           ->user_back());
6693   Function *OutlinedFn = dyn_cast<Function>(TaskAllocCall->getArgOperand(5));
6694   ASSERT_NE(OutlinedFn, nullptr);
6695   ASSERT_EQ(OutlinedFn->arg_size(), 1U);
6696 }
6697 
6698 TEST_F(OpenMPIRBuilderTest, CreateTaskUntied) {
6699   using InsertPointTy = OpenMPIRBuilder::InsertPointTy;
6700   OpenMPIRBuilder OMPBuilder(*M);
6701   OMPBuilder.Config.IsTargetDevice = false;
6702   OMPBuilder.initialize();
6703   F->setName("func");
6704   IRBuilder<> Builder(BB);
6705   auto BodyGenCB = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP) {
6706     return Error::success();
6707   };
6708   BasicBlock *AllocaBB = Builder.GetInsertBlock();
6709   BasicBlock *BodyBB = splitBB(Builder, /*CreateBranch=*/true, "alloca.split");
6710   OpenMPIRBuilder::LocationDescription Loc(
6711       InsertPointTy(BodyBB, BodyBB->getFirstInsertionPt()), DL);
6712   OpenMPIRBuilder::InsertPointOrErrorTy AfterIP = OMPBuilder.createTask(
6713       Loc, InsertPointTy(AllocaBB, AllocaBB->getFirstInsertionPt()), BodyGenCB,
6714       /*Tied=*/false);
6715   assert(AfterIP && "unexpected error");
6716   Builder.restoreIP(*AfterIP);
6717   OMPBuilder.finalize();
6718   Builder.CreateRetVoid();
6719 
6720   // Check for the `Tied` argument
6721   CallInst *TaskAllocCall = dyn_cast<CallInst>(
6722       OMPBuilder.getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_omp_task_alloc)
6723           ->user_back());
6724   ASSERT_NE(TaskAllocCall, nullptr);
6725   ConstantInt *Flags = dyn_cast<ConstantInt>(TaskAllocCall->getArgOperand(2));
6726   ASSERT_NE(Flags, nullptr);
6727   EXPECT_EQ(Flags->getZExtValue() & 1U, 0U);
6728 
6729   EXPECT_FALSE(verifyModule(*M, &errs()));
6730 }
6731 
6732 TEST_F(OpenMPIRBuilderTest, CreateTaskDepend) {
6733   using InsertPointTy = OpenMPIRBuilder::InsertPointTy;
6734   OpenMPIRBuilder OMPBuilder(*M);
6735   OMPBuilder.Config.IsTargetDevice = false;
6736   OMPBuilder.initialize();
6737   F->setName("func");
6738   IRBuilder<> Builder(BB);
6739   auto BodyGenCB = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP) {
6740     return Error::success();
6741   };
6742   BasicBlock *AllocaBB = Builder.GetInsertBlock();
6743   BasicBlock *BodyBB = splitBB(Builder, /*CreateBranch=*/true, "alloca.split");
6744   OpenMPIRBuilder::LocationDescription Loc(
6745       InsertPointTy(BodyBB, BodyBB->getFirstInsertionPt()), DL);
6746   AllocaInst *InDep = Builder.CreateAlloca(Type::getInt32Ty(M->getContext()));
6747   SmallVector<OpenMPIRBuilder::DependData> DDS;
6748   {
6749     OpenMPIRBuilder::DependData DDIn(RTLDependenceKindTy::DepIn,
6750                                      Type::getInt32Ty(M->getContext()), InDep);
6751     DDS.push_back(DDIn);
6752   }
6753   OpenMPIRBuilder::InsertPointOrErrorTy AfterIP = OMPBuilder.createTask(
6754       Loc, InsertPointTy(AllocaBB, AllocaBB->getFirstInsertionPt()), BodyGenCB,
6755       /*Tied=*/false, /*Final*/ nullptr, /*IfCondition*/ nullptr, DDS);
6756   assert(AfterIP && "unexpected error");
6757   Builder.restoreIP(*AfterIP);
6758   OMPBuilder.finalize();
6759   Builder.CreateRetVoid();
6760 
6761   // Check for the `NumDeps` argument
6762   CallInst *TaskAllocCall = dyn_cast<CallInst>(
6763       OMPBuilder
6764           .getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_omp_task_with_deps)
6765           ->user_back());
6766   ASSERT_NE(TaskAllocCall, nullptr);
6767   ConstantInt *NumDeps = dyn_cast<ConstantInt>(TaskAllocCall->getArgOperand(3));
6768   ASSERT_NE(NumDeps, nullptr);
6769   EXPECT_EQ(NumDeps->getZExtValue(), 1U);
6770 
6771   // Check for the `DepInfo` array argument
6772   AllocaInst *DepArray = dyn_cast<AllocaInst>(TaskAllocCall->getOperand(4));
6773   ASSERT_NE(DepArray, nullptr);
6774   Value::user_iterator DepArrayI = DepArray->user_begin();
6775   ++DepArrayI;
6776   Value::user_iterator DepInfoI = DepArrayI->user_begin();
6777   // Check for the `DependKind` flag in the `DepInfo` array
6778   Value *Flag = findStoredValue<GetElementPtrInst>(*DepInfoI);
6779   ASSERT_NE(Flag, nullptr);
6780   ConstantInt *FlagInt = dyn_cast<ConstantInt>(Flag);
6781   ASSERT_NE(FlagInt, nullptr);
6782   EXPECT_EQ(FlagInt->getZExtValue(),
6783             static_cast<unsigned int>(RTLDependenceKindTy::DepIn));
6784   ++DepInfoI;
6785   // Check for the size in the `DepInfo` array
6786   Value *Size = findStoredValue<GetElementPtrInst>(*DepInfoI);
6787   ASSERT_NE(Size, nullptr);
6788   ConstantInt *SizeInt = dyn_cast<ConstantInt>(Size);
6789   ASSERT_NE(SizeInt, nullptr);
6790   EXPECT_EQ(SizeInt->getZExtValue(), 4U);
6791   ++DepInfoI;
6792   // Check for the variable address in the `DepInfo` array
6793   Value *AddrStored = findStoredValue<GetElementPtrInst>(*DepInfoI);
6794   ASSERT_NE(AddrStored, nullptr);
6795   PtrToIntInst *AddrInt = dyn_cast<PtrToIntInst>(AddrStored);
6796   ASSERT_NE(AddrInt, nullptr);
6797   Value *Addr = AddrInt->getPointerOperand();
6798   EXPECT_EQ(Addr, InDep);
6799 
6800   ConstantInt *NumDepsNoAlias =
6801       dyn_cast<ConstantInt>(TaskAllocCall->getArgOperand(5));
6802   ASSERT_NE(NumDepsNoAlias, nullptr);
6803   EXPECT_EQ(NumDepsNoAlias->getZExtValue(), 0U);
6804   EXPECT_EQ(TaskAllocCall->getOperand(6),
6805             ConstantPointerNull::get(PointerType::getUnqual(M->getContext())));
6806 
6807   EXPECT_FALSE(verifyModule(*M, &errs()));
6808 }
6809 
6810 TEST_F(OpenMPIRBuilderTest, CreateTaskFinal) {
6811   using InsertPointTy = OpenMPIRBuilder::InsertPointTy;
6812   OpenMPIRBuilder OMPBuilder(*M);
6813   OMPBuilder.Config.IsTargetDevice = false;
6814   OMPBuilder.initialize();
6815   F->setName("func");
6816   IRBuilder<> Builder(BB);
6817   auto BodyGenCB = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP) {
6818     return Error::success();
6819   };
6820   BasicBlock *BodyBB = splitBB(Builder, /*CreateBranch=*/true, "alloca.split");
6821   IRBuilderBase::InsertPoint AllocaIP = Builder.saveIP();
6822   Builder.SetInsertPoint(BodyBB);
6823   Value *Final = Builder.CreateICmp(
6824       CmpInst::Predicate::ICMP_EQ, F->getArg(0),
6825       ConstantInt::get(Type::getInt32Ty(M->getContext()), 0U));
6826   OpenMPIRBuilder::LocationDescription Loc(Builder.saveIP(), DL);
6827   OpenMPIRBuilder::InsertPointOrErrorTy AfterIP =
6828       OMPBuilder.createTask(Loc, AllocaIP, BodyGenCB,
6829                             /*Tied=*/false, Final);
6830   assert(AfterIP && "unexpected error");
6831   Builder.restoreIP(*AfterIP);
6832   OMPBuilder.finalize();
6833   Builder.CreateRetVoid();
6834 
6835   // Check for the `Tied` argument
6836   CallInst *TaskAllocCall = dyn_cast<CallInst>(
6837       OMPBuilder.getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_omp_task_alloc)
6838           ->user_back());
6839   ASSERT_NE(TaskAllocCall, nullptr);
6840   BinaryOperator *OrInst =
6841       dyn_cast<BinaryOperator>(TaskAllocCall->getArgOperand(2));
6842   ASSERT_NE(OrInst, nullptr);
6843   EXPECT_EQ(OrInst->getOpcode(), BinaryOperator::BinaryOps::Or);
6844 
6845   // One of the arguments to `or` instruction is the tied flag, which is equal
6846   // to zero.
6847   EXPECT_TRUE(any_of(OrInst->operands(), [](Value *op) {
6848     if (ConstantInt *TiedValue = dyn_cast<ConstantInt>(op))
6849       return TiedValue->getSExtValue() == 0;
6850     return false;
6851   }));
6852 
6853   // One of the arguments to `or` instruction is the final condition.
6854   EXPECT_TRUE(any_of(OrInst->operands(), [Final](Value *op) {
6855     if (SelectInst *Select = dyn_cast<SelectInst>(op)) {
6856       ConstantInt *TrueValue = dyn_cast<ConstantInt>(Select->getTrueValue());
6857       ConstantInt *FalseValue = dyn_cast<ConstantInt>(Select->getFalseValue());
6858       if (!TrueValue || !FalseValue)
6859         return false;
6860       return Select->getCondition() == Final &&
6861              TrueValue->getSExtValue() == 2 && FalseValue->getSExtValue() == 0;
6862     }
6863     return false;
6864   }));
6865 
6866   EXPECT_FALSE(verifyModule(*M, &errs()));
6867 }
6868 
6869 TEST_F(OpenMPIRBuilderTest, CreateTaskIfCondition) {
6870   using InsertPointTy = OpenMPIRBuilder::InsertPointTy;
6871   OpenMPIRBuilder OMPBuilder(*M);
6872   OMPBuilder.Config.IsTargetDevice = false;
6873   OMPBuilder.initialize();
6874   F->setName("func");
6875   IRBuilder<> Builder(BB);
6876   auto BodyGenCB = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP) {
6877     return Error::success();
6878   };
6879   BasicBlock *BodyBB = splitBB(Builder, /*CreateBranch=*/true, "alloca.split");
6880   IRBuilderBase::InsertPoint AllocaIP = Builder.saveIP();
6881   Builder.SetInsertPoint(BodyBB);
6882   Value *IfCondition = Builder.CreateICmp(
6883       CmpInst::Predicate::ICMP_EQ, F->getArg(0),
6884       ConstantInt::get(Type::getInt32Ty(M->getContext()), 0U));
6885   OpenMPIRBuilder::LocationDescription Loc(Builder.saveIP(), DL);
6886   OpenMPIRBuilder::InsertPointOrErrorTy AfterIP =
6887       OMPBuilder.createTask(Loc, AllocaIP, BodyGenCB,
6888                             /*Tied=*/false, /*Final=*/nullptr, IfCondition);
6889   assert(AfterIP && "unexpected error");
6890   Builder.restoreIP(*AfterIP);
6891   OMPBuilder.finalize();
6892   Builder.CreateRetVoid();
6893 
6894   EXPECT_FALSE(verifyModule(*M, &errs()));
6895 
6896   CallInst *TaskAllocCall = dyn_cast<CallInst>(
6897       OMPBuilder.getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_omp_task_alloc)
6898           ->user_back());
6899   ASSERT_NE(TaskAllocCall, nullptr);
6900 
6901   // Check the branching is based on the if condition argument.
6902   BranchInst *IfConditionBranchInst =
6903       dyn_cast<BranchInst>(TaskAllocCall->getParent()->getTerminator());
6904   ASSERT_NE(IfConditionBranchInst, nullptr);
6905   ASSERT_TRUE(IfConditionBranchInst->isConditional());
6906   EXPECT_EQ(IfConditionBranchInst->getCondition(), IfCondition);
6907 
6908   // Check that the `__kmpc_omp_task` executes only in the then branch.
6909   CallInst *TaskCall = dyn_cast<CallInst>(
6910       OMPBuilder.getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_omp_task)
6911           ->user_back());
6912   ASSERT_NE(TaskCall, nullptr);
6913   EXPECT_EQ(TaskCall->getParent(), IfConditionBranchInst->getSuccessor(0));
6914 
6915   // Check that the OpenMP Runtime Functions specific to `if` clause execute
6916   // only in the else branch. Also check that the function call is between the
6917   // `__kmpc_omp_task_begin_if0` and `__kmpc_omp_task_complete_if0` calls.
6918   CallInst *TaskBeginIfCall = dyn_cast<CallInst>(
6919       OMPBuilder
6920           .getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_omp_task_begin_if0)
6921           ->user_back());
6922   CallInst *TaskCompleteCall = dyn_cast<CallInst>(
6923       OMPBuilder
6924           .getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_omp_task_complete_if0)
6925           ->user_back());
6926   ASSERT_NE(TaskBeginIfCall, nullptr);
6927   ASSERT_NE(TaskCompleteCall, nullptr);
6928   Function *OulinedFn =
6929       dyn_cast<Function>(TaskAllocCall->getArgOperand(5)->stripPointerCasts());
6930   ASSERT_NE(OulinedFn, nullptr);
6931   CallInst *OulinedFnCall = dyn_cast<CallInst>(OulinedFn->user_back());
6932   ASSERT_NE(OulinedFnCall, nullptr);
6933   EXPECT_EQ(TaskBeginIfCall->getParent(),
6934             IfConditionBranchInst->getSuccessor(1));
6935 
6936   EXPECT_EQ(TaskBeginIfCall->getNextNonDebugInstruction(), OulinedFnCall);
6937   EXPECT_EQ(OulinedFnCall->getNextNonDebugInstruction(), TaskCompleteCall);
6938 }
6939 
6940 TEST_F(OpenMPIRBuilderTest, CreateTaskgroup) {
6941   using InsertPointTy = OpenMPIRBuilder::InsertPointTy;
6942   OpenMPIRBuilder OMPBuilder(*M);
6943   OMPBuilder.initialize();
6944   F->setName("func");
6945   IRBuilder<> Builder(BB);
6946 
6947   AllocaInst *ValPtr32 = Builder.CreateAlloca(Builder.getInt32Ty());
6948   AllocaInst *ValPtr128 = Builder.CreateAlloca(Builder.getInt128Ty());
6949   Value *Val128 =
6950       Builder.CreateLoad(Builder.getInt128Ty(), ValPtr128, "bodygen.load");
6951   Instruction *ThenTerm, *ElseTerm;
6952 
6953   Value *InternalStoreInst, *InternalLoad32, *InternalLoad128, *InternalIfCmp;
6954 
6955   auto BodyGenCB = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP) {
6956     Builder.restoreIP(AllocaIP);
6957     AllocaInst *Local128 = Builder.CreateAlloca(Builder.getInt128Ty(), nullptr,
6958                                                 "bodygen.alloca128");
6959 
6960     Builder.restoreIP(CodeGenIP);
6961     // Loading and storing captured pointer and values
6962     InternalStoreInst = Builder.CreateStore(Val128, Local128);
6963     InternalLoad32 = Builder.CreateLoad(ValPtr32->getAllocatedType(), ValPtr32,
6964                                         "bodygen.load32");
6965 
6966     InternalLoad128 = Builder.CreateLoad(Local128->getAllocatedType(), Local128,
6967                                          "bodygen.local.load128");
6968     InternalIfCmp = Builder.CreateICmpNE(
6969         InternalLoad32,
6970         Builder.CreateTrunc(InternalLoad128, InternalLoad32->getType()));
6971     SplitBlockAndInsertIfThenElse(InternalIfCmp,
6972                                   CodeGenIP.getBlock()->getTerminator(),
6973                                   &ThenTerm, &ElseTerm);
6974     return Error::success();
6975   };
6976 
6977   BasicBlock *AllocaBB = Builder.GetInsertBlock();
6978   BasicBlock *BodyBB = splitBB(Builder, /*CreateBranch=*/true, "alloca.split");
6979   OpenMPIRBuilder::LocationDescription Loc(
6980       InsertPointTy(BodyBB, BodyBB->getFirstInsertionPt()), DL);
6981   OpenMPIRBuilder::InsertPointOrErrorTy AfterIP = OMPBuilder.createTaskgroup(
6982       Loc, InsertPointTy(AllocaBB, AllocaBB->getFirstInsertionPt()), BodyGenCB);
6983   assert(AfterIP && "unexpected error");
6984   Builder.restoreIP(*AfterIP);
6985   OMPBuilder.finalize();
6986   Builder.CreateRetVoid();
6987 
6988   EXPECT_FALSE(verifyModule(*M, &errs()));
6989 
6990   CallInst *TaskgroupCall = dyn_cast<CallInst>(
6991       OMPBuilder.getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_taskgroup)
6992           ->user_back());
6993   ASSERT_NE(TaskgroupCall, nullptr);
6994   CallInst *EndTaskgroupCall = dyn_cast<CallInst>(
6995       OMPBuilder.getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_end_taskgroup)
6996           ->user_back());
6997   ASSERT_NE(EndTaskgroupCall, nullptr);
6998 
6999   // Verify the Ident argument
7000   GlobalVariable *Ident = cast<GlobalVariable>(TaskgroupCall->getArgOperand(0));
7001   ASSERT_NE(Ident, nullptr);
7002   EXPECT_TRUE(Ident->hasInitializer());
7003   Constant *Initializer = Ident->getInitializer();
7004   GlobalVariable *SrcStrGlob =
7005       cast<GlobalVariable>(Initializer->getOperand(4)->stripPointerCasts());
7006   ASSERT_NE(SrcStrGlob, nullptr);
7007   ConstantDataArray *SrcSrc =
7008       dyn_cast<ConstantDataArray>(SrcStrGlob->getInitializer());
7009   ASSERT_NE(SrcSrc, nullptr);
7010 
7011   // Verify the num_threads argument.
7012   CallInst *GTID = dyn_cast<CallInst>(TaskgroupCall->getArgOperand(1));
7013   ASSERT_NE(GTID, nullptr);
7014   EXPECT_EQ(GTID->arg_size(), 1U);
7015   EXPECT_EQ(GTID->getCalledFunction(), OMPBuilder.getOrCreateRuntimeFunctionPtr(
7016                                            OMPRTL___kmpc_global_thread_num));
7017 
7018   // Checking the general structure of the IR generated is same as expected.
7019   Instruction *GeneratedStoreInst = TaskgroupCall->getNextNonDebugInstruction();
7020   EXPECT_EQ(GeneratedStoreInst, InternalStoreInst);
7021   Instruction *GeneratedLoad32 =
7022       GeneratedStoreInst->getNextNonDebugInstruction();
7023   EXPECT_EQ(GeneratedLoad32, InternalLoad32);
7024   Instruction *GeneratedLoad128 = GeneratedLoad32->getNextNonDebugInstruction();
7025   EXPECT_EQ(GeneratedLoad128, InternalLoad128);
7026 
7027   // Checking the ordering because of the if statements and that
7028   // `__kmp_end_taskgroup` call is after the if branching.
7029   BasicBlock *RefOrder[] = {TaskgroupCall->getParent(), ThenTerm->getParent(),
7030                             ThenTerm->getSuccessor(0),
7031                             EndTaskgroupCall->getParent(),
7032                             ElseTerm->getParent()};
7033   verifyDFSOrder(F, RefOrder);
7034 }
7035 
7036 TEST_F(OpenMPIRBuilderTest, CreateTaskgroupWithTasks) {
7037   using InsertPointTy = OpenMPIRBuilder::InsertPointTy;
7038   OpenMPIRBuilder OMPBuilder(*M);
7039   OMPBuilder.Config.IsTargetDevice = false;
7040   OMPBuilder.initialize();
7041   F->setName("func");
7042   IRBuilder<> Builder(BB);
7043 
7044   auto BodyGenCB = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP) {
7045     Builder.restoreIP(AllocaIP);
7046     AllocaInst *Alloca32 =
7047         Builder.CreateAlloca(Builder.getInt32Ty(), nullptr, "bodygen.alloca32");
7048     AllocaInst *Alloca64 =
7049         Builder.CreateAlloca(Builder.getInt64Ty(), nullptr, "bodygen.alloca64");
7050     Builder.restoreIP(CodeGenIP);
7051     auto TaskBodyGenCB1 = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP) {
7052       Builder.restoreIP(CodeGenIP);
7053       LoadInst *LoadValue =
7054           Builder.CreateLoad(Alloca64->getAllocatedType(), Alloca64);
7055       Value *AddInst = Builder.CreateAdd(LoadValue, Builder.getInt64(64));
7056       Builder.CreateStore(AddInst, Alloca64);
7057       return Error::success();
7058     };
7059     OpenMPIRBuilder::LocationDescription Loc(Builder.saveIP(), DL);
7060     OpenMPIRBuilder::InsertPointOrErrorTy TaskIP1 =
7061         OMPBuilder.createTask(Loc, AllocaIP, TaskBodyGenCB1);
7062     assert(TaskIP1 && "unexpected error");
7063     Builder.restoreIP(*TaskIP1);
7064 
7065     auto TaskBodyGenCB2 = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP) {
7066       Builder.restoreIP(CodeGenIP);
7067       LoadInst *LoadValue =
7068           Builder.CreateLoad(Alloca32->getAllocatedType(), Alloca32);
7069       Value *AddInst = Builder.CreateAdd(LoadValue, Builder.getInt32(32));
7070       Builder.CreateStore(AddInst, Alloca32);
7071       return Error::success();
7072     };
7073     OpenMPIRBuilder::LocationDescription Loc2(Builder.saveIP(), DL);
7074     OpenMPIRBuilder::InsertPointOrErrorTy TaskIP2 =
7075         OMPBuilder.createTask(Loc2, AllocaIP, TaskBodyGenCB2);
7076     assert(TaskIP2 && "unexpected error");
7077     Builder.restoreIP(*TaskIP2);
7078     return Error::success();
7079   };
7080 
7081   BasicBlock *AllocaBB = Builder.GetInsertBlock();
7082   BasicBlock *BodyBB = splitBB(Builder, /*CreateBranch=*/true, "alloca.split");
7083   OpenMPIRBuilder::LocationDescription Loc(
7084       InsertPointTy(BodyBB, BodyBB->getFirstInsertionPt()), DL);
7085   OpenMPIRBuilder::InsertPointOrErrorTy AfterIP = OMPBuilder.createTaskgroup(
7086       Loc, InsertPointTy(AllocaBB, AllocaBB->getFirstInsertionPt()), BodyGenCB);
7087   assert(AfterIP && "unexpected error");
7088   Builder.restoreIP(*AfterIP);
7089   OMPBuilder.finalize();
7090   Builder.CreateRetVoid();
7091 
7092   EXPECT_FALSE(verifyModule(*M, &errs()));
7093 
7094   CallInst *TaskgroupCall = dyn_cast<CallInst>(
7095       OMPBuilder.getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_taskgroup)
7096           ->user_back());
7097   ASSERT_NE(TaskgroupCall, nullptr);
7098   CallInst *EndTaskgroupCall = dyn_cast<CallInst>(
7099       OMPBuilder.getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_end_taskgroup)
7100           ->user_back());
7101   ASSERT_NE(EndTaskgroupCall, nullptr);
7102 
7103   Function *TaskAllocFn =
7104       OMPBuilder.getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_omp_task_alloc);
7105   ASSERT_EQ(TaskAllocFn->getNumUses(), 2u);
7106 
7107   CallInst *FirstTaskAllocCall =
7108       dyn_cast_or_null<CallInst>(*TaskAllocFn->users().begin());
7109   CallInst *SecondTaskAllocCall =
7110       dyn_cast_or_null<CallInst>(*TaskAllocFn->users().begin()++);
7111   ASSERT_NE(FirstTaskAllocCall, nullptr);
7112   ASSERT_NE(SecondTaskAllocCall, nullptr);
7113 
7114   // Verify that the tasks have been generated in order and inside taskgroup
7115   // construct.
7116   BasicBlock *RefOrder[] = {
7117       TaskgroupCall->getParent(), FirstTaskAllocCall->getParent(),
7118       SecondTaskAllocCall->getParent(), EndTaskgroupCall->getParent()};
7119   verifyDFSOrder(F, RefOrder);
7120 }
7121 
7122 TEST_F(OpenMPIRBuilderTest, EmitOffloadingArraysArguments) {
7123   OpenMPIRBuilder OMPBuilder(*M);
7124   OMPBuilder.initialize();
7125 
7126   IRBuilder<> Builder(BB);
7127 
7128   OpenMPIRBuilder::TargetDataRTArgs RTArgs;
7129   OpenMPIRBuilder::TargetDataInfo Info(true, false);
7130 
7131   auto VoidPtrPtrTy = PointerType::getUnqual(Builder.getContext());
7132   auto Int64PtrTy = PointerType::getUnqual(Builder.getContext());
7133 
7134   Info.RTArgs.BasePointersArray = ConstantPointerNull::get(Builder.getPtrTy(0));
7135   Info.RTArgs.PointersArray = ConstantPointerNull::get(Builder.getPtrTy(0));
7136   Info.RTArgs.SizesArray = ConstantPointerNull::get(Builder.getPtrTy(0));
7137   Info.RTArgs.MapTypesArray = ConstantPointerNull::get(Builder.getPtrTy(0));
7138   Info.RTArgs.MapNamesArray = ConstantPointerNull::get(Builder.getPtrTy(0));
7139   Info.RTArgs.MappersArray = ConstantPointerNull::get(Builder.getPtrTy(0));
7140   Info.NumberOfPtrs = 4;
7141   Info.EmitDebug = false;
7142   OMPBuilder.emitOffloadingArraysArgument(Builder, RTArgs, Info, false);
7143 
7144   EXPECT_NE(RTArgs.BasePointersArray, nullptr);
7145   EXPECT_NE(RTArgs.PointersArray, nullptr);
7146   EXPECT_NE(RTArgs.SizesArray, nullptr);
7147   EXPECT_NE(RTArgs.MapTypesArray, nullptr);
7148   EXPECT_NE(RTArgs.MappersArray, nullptr);
7149   EXPECT_NE(RTArgs.MapNamesArray, nullptr);
7150   EXPECT_EQ(RTArgs.MapTypesArrayEnd, nullptr);
7151 
7152   EXPECT_EQ(RTArgs.BasePointersArray->getType(), VoidPtrPtrTy);
7153   EXPECT_EQ(RTArgs.PointersArray->getType(), VoidPtrPtrTy);
7154   EXPECT_EQ(RTArgs.SizesArray->getType(), Int64PtrTy);
7155   EXPECT_EQ(RTArgs.MapTypesArray->getType(), Int64PtrTy);
7156   EXPECT_EQ(RTArgs.MappersArray->getType(), VoidPtrPtrTy);
7157   EXPECT_EQ(RTArgs.MapNamesArray->getType(), VoidPtrPtrTy);
7158 }
7159 
7160 TEST_F(OpenMPIRBuilderTest, OffloadEntriesInfoManager) {
7161   OpenMPIRBuilder OMPBuilder(*M);
7162   OMPBuilder.setConfig(
7163       OpenMPIRBuilderConfig(true, false, false, false, false, false, false));
7164   OffloadEntriesInfoManager &InfoManager = OMPBuilder.OffloadInfoManager;
7165   TargetRegionEntryInfo EntryInfo("parent", 1, 2, 4, 0);
7166   InfoManager.initializeTargetRegionEntryInfo(EntryInfo, 0);
7167   EXPECT_TRUE(InfoManager.hasTargetRegionEntryInfo(EntryInfo));
7168   InfoManager.initializeDeviceGlobalVarEntryInfo(
7169       "gvar", OffloadEntriesInfoManager::OMPTargetGlobalVarEntryTo, 0);
7170   InfoManager.registerTargetRegionEntryInfo(
7171       EntryInfo, nullptr, nullptr,
7172       OffloadEntriesInfoManager::OMPTargetRegionEntryTargetRegion);
7173   InfoManager.registerDeviceGlobalVarEntryInfo(
7174       "gvar", 0x0, 8, OffloadEntriesInfoManager::OMPTargetGlobalVarEntryTo,
7175       GlobalValue::WeakAnyLinkage);
7176   EXPECT_TRUE(InfoManager.hasDeviceGlobalVarEntryInfo("gvar"));
7177 }
7178 
7179 // Tests both registerTargetGlobalVariable and getAddrOfDeclareTargetVar as they
7180 // call each other (recursively in some cases). The test case test these
7181 // functions by utilising them for host code generation for declare target
7182 // global variables
7183 TEST_F(OpenMPIRBuilderTest, registerTargetGlobalVariable) {
7184   OpenMPIRBuilder OMPBuilder(*M);
7185   OMPBuilder.initialize();
7186   OpenMPIRBuilderConfig Config(false, false, false, false, false, false, false);
7187   OMPBuilder.setConfig(Config);
7188 
7189   std::vector<llvm::Triple> TargetTriple;
7190   TargetTriple.emplace_back("amdgcn-amd-amdhsa");
7191 
7192   TargetRegionEntryInfo EntryInfo("", 42, 4711, 17);
7193   std::vector<GlobalVariable *> RefsGathered;
7194 
7195   std::vector<Constant *> Globals;
7196   auto *IntTy = Type::getInt32Ty(Ctx);
7197   for (int I = 0; I < 2; ++I) {
7198     Globals.push_back(M->getOrInsertGlobal(
7199         "test_data_int_" + std::to_string(I), IntTy, [&]() -> GlobalVariable * {
7200           return new GlobalVariable(
7201               *M, IntTy, false, GlobalValue::LinkageTypes::WeakAnyLinkage,
7202               ConstantInt::get(IntTy, I), "test_data_int_" + std::to_string(I));
7203         }));
7204   }
7205 
7206   OMPBuilder.registerTargetGlobalVariable(
7207       OffloadEntriesInfoManager::OMPTargetGlobalVarEntryTo,
7208       OffloadEntriesInfoManager::OMPTargetDeviceClauseAny, false, true,
7209       EntryInfo, Globals[0]->getName(), RefsGathered, false, TargetTriple,
7210       nullptr, nullptr, Globals[0]->getType(), Globals[0]);
7211 
7212   OMPBuilder.registerTargetGlobalVariable(
7213       OffloadEntriesInfoManager::OMPTargetGlobalVarEntryLink,
7214       OffloadEntriesInfoManager::OMPTargetDeviceClauseAny, false, true,
7215       EntryInfo, Globals[1]->getName(), RefsGathered, false, TargetTriple,
7216       nullptr, nullptr, Globals[1]->getType(), Globals[1]);
7217 
7218   llvm::OpenMPIRBuilder::EmitMetadataErrorReportFunctionTy &&ErrorReportfn =
7219       [](llvm::OpenMPIRBuilder::EmitMetadataErrorKind Kind,
7220          const llvm::TargetRegionEntryInfo &EntryInfo) -> void {
7221     // If this is invoked, then we want to emit an error, even if it is not
7222     // neccesarily the most readable, as something has went wrong. The
7223     // test-suite unfortunately eats up all error output
7224     ASSERT_EQ(Kind, Kind);
7225   };
7226 
7227   OMPBuilder.createOffloadEntriesAndInfoMetadata(ErrorReportfn);
7228 
7229   // Clauses for data_int_0 with To + Any clauses for the host
7230   std::vector<GlobalVariable *> OffloadEntries;
7231   OffloadEntries.push_back(M->getNamedGlobal(".offloading.entry_name"));
7232   OffloadEntries.push_back(
7233       M->getNamedGlobal(".offloading.entry.test_data_int_0"));
7234 
7235   // Clauses for data_int_1 with Link + Any clauses for the host
7236   OffloadEntries.push_back(
7237       M->getNamedGlobal("test_data_int_1_decl_tgt_ref_ptr"));
7238   OffloadEntries.push_back(M->getNamedGlobal(".offloading.entry_name.1"));
7239   OffloadEntries.push_back(
7240       M->getNamedGlobal(".offloading.entry.test_data_int_1_decl_tgt_ref_ptr"));
7241 
7242   for (unsigned I = 0; I < OffloadEntries.size(); ++I)
7243     EXPECT_NE(OffloadEntries[I], nullptr);
7244 
7245   // Metadata generated for the host offload module
7246   NamedMDNode *OffloadMetadata = M->getNamedMetadata("omp_offload.info");
7247   ASSERT_THAT(OffloadMetadata, testing::NotNull());
7248   StringRef Nodes[2] = {
7249       cast<MDString>(OffloadMetadata->getOperand(0)->getOperand(1))
7250           ->getString(),
7251       cast<MDString>(OffloadMetadata->getOperand(1)->getOperand(1))
7252           ->getString()};
7253   EXPECT_THAT(
7254       Nodes, testing::UnorderedElementsAre("test_data_int_0",
7255                                            "test_data_int_1_decl_tgt_ref_ptr"));
7256 }
7257 
7258 TEST_F(OpenMPIRBuilderTest, createGPUOffloadEntry) {
7259   OpenMPIRBuilder OMPBuilder(*M);
7260   OMPBuilder.initialize();
7261   OpenMPIRBuilderConfig Config(/* IsTargetDevice = */ true,
7262                                /* IsGPU = */ true,
7263                                /* OpenMPOffloadMandatory = */ false,
7264                                /* HasRequiresReverseOffload = */ false,
7265                                /* HasRequiresUnifiedAddress = */ false,
7266                                /* HasRequiresUnifiedSharedMemory = */ false,
7267                                /* HasRequiresDynamicAllocators = */ false);
7268   OMPBuilder.setConfig(Config);
7269 
7270   FunctionCallee FnTypeAndCallee =
7271       M->getOrInsertFunction("test_kernel", Type::getVoidTy(Ctx));
7272 
7273   auto *Fn = cast<Function>(FnTypeAndCallee.getCallee());
7274   OMPBuilder.createOffloadEntry(/* ID = */ nullptr, Fn,
7275                                 /* Size = */ 0,
7276                                 /* Flags = */ 0, GlobalValue::WeakAnyLinkage);
7277 
7278   // Check nvvm.annotations only created for GPU kernels
7279   NamedMDNode *MD = M->getNamedMetadata("nvvm.annotations");
7280   EXPECT_NE(MD, nullptr);
7281   EXPECT_EQ(MD->getNumOperands(), 1u);
7282 
7283   MDNode *Annotations = MD->getOperand(0);
7284   EXPECT_EQ(Annotations->getNumOperands(), 3u);
7285 
7286   Constant *ConstVal =
7287       dyn_cast<ConstantAsMetadata>(Annotations->getOperand(0))->getValue();
7288   EXPECT_TRUE(isa<Function>(Fn));
7289   EXPECT_EQ(ConstVal, cast<Function>(Fn));
7290 
7291   EXPECT_TRUE(Annotations->getOperand(1).equalsStr("kernel"));
7292 
7293   EXPECT_TRUE(mdconst::hasa<ConstantInt>(Annotations->getOperand(2)));
7294   APInt IntVal =
7295       mdconst::extract<ConstantInt>(Annotations->getOperand(2))->getValue();
7296   EXPECT_EQ(IntVal, 1);
7297 
7298   // Check kernel attributes
7299   EXPECT_TRUE(Fn->hasFnAttribute("kernel"));
7300   EXPECT_TRUE(Fn->hasFnAttribute(Attribute::MustProgress));
7301 }
7302 
7303 } // namespace
7304