xref: /llvm-project/llvm/unittests/Analysis/AssumeBundleQueriesTest.cpp (revision 32b38d248fd3c75abc5c86ab6677b6cb08a703cc)
1 //===- AssumeBundleQueriesTest.cpp ------------------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "llvm/Analysis/AssumptionCache.h"
10 #include "llvm/Analysis/AssumeBundleQueries.h"
11 #include "llvm/AsmParser/Parser.h"
12 #include "llvm/IR/LLVMContext.h"
13 #include "llvm/IR/IntrinsicInst.h"
14 #include "llvm/Support/Regex.h"
15 #include "llvm/Support/SourceMgr.h"
16 #include "llvm/Support/CommandLine.h"
17 #include "llvm/Transforms/Utils/AssumeBundleBuilder.h"
18 #include "gtest/gtest.h"
19 #include <random>
20 
21 using namespace llvm;
22 
23 namespace llvm {
24 extern cl::opt<bool> ShouldPreserveAllAttributes;
25 extern cl::opt<bool> EnableKnowledgeRetention;
26 } // namespace llvm
27 
28 static void RunTest(
29     StringRef Head, StringRef Tail,
30     std::vector<std::pair<StringRef, llvm::function_ref<void(Instruction *)>>>
31         &Tests) {
32   for (auto &Elem : Tests) {
33     std::string IR;
34     IR.append(Head.begin(), Head.end());
35     IR.append(Elem.first.begin(), Elem.first.end());
36     IR.append(Tail.begin(), Tail.end());
37     LLVMContext C;
38     SMDiagnostic Err;
39     std::unique_ptr<Module> Mod = parseAssemblyString(IR, Err, C);
40     if (!Mod)
41       Err.print("AssumeQueryAPI", errs());
42     Elem.second(&*(Mod->getFunction("test")->begin()->begin()));
43   }
44 }
45 
46 bool hasMatchesExactlyAttributes(AssumeInst *Assume, Value *WasOn,
47                                  StringRef AttrToMatch) {
48   Regex Reg(AttrToMatch);
49   SmallVector<StringRef, 1> Matches;
50   for (StringRef Attr : {
51 #define GET_ATTR_NAMES
52 #define ATTRIBUTE_ALL(ENUM_NAME, DISPLAY_NAME) StringRef(#DISPLAY_NAME),
53 #include "llvm/IR/Attributes.inc"
54        }) {
55     bool ShouldHaveAttr = Reg.match(Attr, &Matches) && Matches[0] == Attr;
56     if (ShouldHaveAttr != hasAttributeInAssume(*Assume, WasOn, Attr))
57       return false;
58   }
59   return true;
60 }
61 
62 bool hasTheRightValue(AssumeInst *Assume, Value *WasOn,
63                       Attribute::AttrKind Kind, unsigned Value) {
64   uint64_t ArgVal = 0;
65   if (!hasAttributeInAssume(*Assume, WasOn, Kind, &ArgVal))
66     return false;
67   if (ArgVal != Value)
68     return false;
69   return true;
70 }
71 
72 TEST(AssumeQueryAPI, hasAttributeInAssume) {
73   EnableKnowledgeRetention.setValue(true);
74   StringRef Head =
75       "declare void @llvm.assume(i1)\n"
76       "declare void @func(i32*, i32*, i32*)\n"
77       "declare void @func1(i32*, i32*, i32*, i32*)\n"
78       "declare void @func_many(i32*) \"no-jump-tables\" nounwind "
79       "\"less-precise-fpmad\" willreturn norecurse\n"
80       "define void @test(i32* %P, i32* %P1, i32* %P2, i32* %P3) {\n";
81   StringRef Tail = "ret void\n"
82                    "}";
83   std::vector<std::pair<StringRef, llvm::function_ref<void(Instruction *)>>>
84       Tests;
85   Tests.push_back(std::make_pair(
86       "call void @func(i32* nonnull align 4 dereferenceable(16) %P, i32* align "
87       "8 noalias %P1, i32* align 8 noundef %P2)\n",
88       [](Instruction *I) {
89         auto *Assume = buildAssumeFromInst(I);
90         Assume->insertBefore(I);
91         ASSERT_TRUE(hasMatchesExactlyAttributes(Assume, I->getOperand(0),
92                                        "(nonnull|align|dereferenceable)"));
93         ASSERT_TRUE(hasMatchesExactlyAttributes(Assume, I->getOperand(1),
94                                        "()"));
95         ASSERT_TRUE(hasMatchesExactlyAttributes(Assume, I->getOperand(2),
96                                        "(align|noundef)"));
97         ASSERT_TRUE(hasTheRightValue(Assume, I->getOperand(0),
98                                      Attribute::AttrKind::Dereferenceable, 16));
99         ASSERT_TRUE(hasTheRightValue(Assume, I->getOperand(0),
100                                      Attribute::AttrKind::Alignment, 4));
101         ASSERT_TRUE(hasTheRightValue(Assume, I->getOperand(0),
102                                      Attribute::AttrKind::Alignment, 4));
103       }));
104   Tests.push_back(std::make_pair(
105       "call void @func1(i32* nonnull align 32 dereferenceable(48) %P, i32* "
106       "nonnull "
107       "align 8 dereferenceable(28) %P, i32* nonnull align 64 "
108       "dereferenceable(4) "
109       "%P, i32* nonnull align 16 dereferenceable(12) %P)\n",
110       [](Instruction *I) {
111         auto *Assume = buildAssumeFromInst(I);
112         Assume->insertBefore(I);
113         ASSERT_TRUE(hasMatchesExactlyAttributes(Assume, I->getOperand(0),
114                                        "(nonnull|align|dereferenceable)"));
115         ASSERT_TRUE(hasMatchesExactlyAttributes(Assume, I->getOperand(1),
116                                        "(nonnull|align|dereferenceable)"));
117         ASSERT_TRUE(hasMatchesExactlyAttributes(Assume, I->getOperand(2),
118                                        "(nonnull|align|dereferenceable)"));
119         ASSERT_TRUE(hasMatchesExactlyAttributes(Assume, I->getOperand(3),
120                                        "(nonnull|align|dereferenceable)"));
121         ASSERT_TRUE(hasTheRightValue(Assume, I->getOperand(0),
122                                      Attribute::AttrKind::Dereferenceable, 48));
123         ASSERT_TRUE(hasTheRightValue(Assume, I->getOperand(0),
124                                      Attribute::AttrKind::Alignment, 64));
125         ASSERT_TRUE(hasTheRightValue(Assume, I->getOperand(1),
126                                      Attribute::AttrKind::Alignment, 64));
127       }));
128   Tests.push_back(std::make_pair(
129       "call void @func_many(i32* align 8 noundef %P1) cold\n", [](Instruction *I) {
130         ShouldPreserveAllAttributes.setValue(true);
131         auto *Assume = buildAssumeFromInst(I);
132         Assume->insertBefore(I);
133         ASSERT_TRUE(hasMatchesExactlyAttributes(
134             Assume, nullptr,
135             "(align|nounwind|norecurse|noundef|willreturn|cold)"));
136         ShouldPreserveAllAttributes.setValue(false);
137       }));
138   Tests.push_back(
139       std::make_pair("call void @llvm.assume(i1 true)\n", [](Instruction *I) {
140         auto *Assume = cast<AssumeInst>(I);
141         ASSERT_TRUE(hasMatchesExactlyAttributes(Assume, nullptr, ""));
142       }));
143   Tests.push_back(std::make_pair(
144       "call void @func1(i32* readnone align 32 "
145       "dereferenceable(48) noalias %P, i32* "
146       "align 8 dereferenceable(28) %P1, i32* align 64 "
147       "dereferenceable(4) "
148       "%P2, i32* nonnull align 16 dereferenceable(12) %P3)\n",
149       [](Instruction *I) {
150         auto *Assume = buildAssumeFromInst(I);
151         Assume->insertBefore(I);
152         ASSERT_TRUE(hasMatchesExactlyAttributes(
153             Assume, I->getOperand(0),
154             "(align|dereferenceable)"));
155         ASSERT_TRUE(hasMatchesExactlyAttributes(Assume, I->getOperand(1),
156                                        "(align|dereferenceable)"));
157         ASSERT_TRUE(hasMatchesExactlyAttributes(Assume, I->getOperand(2),
158                                        "(align|dereferenceable)"));
159         ASSERT_TRUE(hasMatchesExactlyAttributes(Assume, I->getOperand(3),
160                                        "(nonnull|align|dereferenceable)"));
161         ASSERT_TRUE(hasTheRightValue(Assume, I->getOperand(0),
162                                      Attribute::AttrKind::Alignment, 32));
163         ASSERT_TRUE(hasTheRightValue(Assume, I->getOperand(0),
164                                      Attribute::AttrKind::Dereferenceable, 48));
165         ASSERT_TRUE(hasTheRightValue(Assume, I->getOperand(1),
166                                      Attribute::AttrKind::Dereferenceable, 28));
167         ASSERT_TRUE(hasTheRightValue(Assume, I->getOperand(1),
168                                      Attribute::AttrKind::Alignment, 8));
169         ASSERT_TRUE(hasTheRightValue(Assume, I->getOperand(2),
170                                      Attribute::AttrKind::Alignment, 64));
171         ASSERT_TRUE(hasTheRightValue(Assume, I->getOperand(2),
172                                      Attribute::AttrKind::Dereferenceable, 4));
173         ASSERT_TRUE(hasTheRightValue(Assume, I->getOperand(3),
174                                      Attribute::AttrKind::Alignment, 16));
175         ASSERT_TRUE(hasTheRightValue(Assume, I->getOperand(3),
176                                      Attribute::AttrKind::Dereferenceable, 12));
177       }));
178 
179   Tests.push_back(std::make_pair(
180       "call void @func1(i32* readnone align 32 "
181       "dereferenceable(48) noalias %P, i32* "
182       "align 8 dereferenceable(28) %P1, i32* align 64 "
183       "dereferenceable(4) "
184       "%P2, i32* nonnull align 16 dereferenceable(12) %P3)\n",
185       [](Instruction *I) {
186         auto *Assume = buildAssumeFromInst(I);
187         Assume->insertBefore(I);
188         I->getOperand(1)->dropDroppableUses();
189         I->getOperand(2)->dropDroppableUses();
190         I->getOperand(3)->dropDroppableUses();
191         ASSERT_TRUE(hasMatchesExactlyAttributes(
192             Assume, I->getOperand(0),
193             "(align|dereferenceable)"));
194         ASSERT_TRUE(hasMatchesExactlyAttributes(Assume, I->getOperand(1),
195                                        ""));
196         ASSERT_TRUE(hasMatchesExactlyAttributes(Assume, I->getOperand(2),
197                                        ""));
198         ASSERT_TRUE(hasMatchesExactlyAttributes(Assume, I->getOperand(3),
199                                        ""));
200         ASSERT_TRUE(hasTheRightValue(Assume, I->getOperand(0),
201                                      Attribute::AttrKind::Alignment, 32));
202         ASSERT_TRUE(hasTheRightValue(Assume, I->getOperand(0),
203                                      Attribute::AttrKind::Dereferenceable, 48));
204       }));
205   Tests.push_back(std::make_pair(
206       "call void @func(i32* nonnull align 4 dereferenceable(16) %P, i32* align "
207       "8 noalias %P1, i32* %P1)\n",
208       [](Instruction *I) {
209         auto *Assume = buildAssumeFromInst(I);
210         Assume->insertBefore(I);
211         Value *New = I->getFunction()->getArg(3);
212         Value *Old = I->getOperand(0);
213         ASSERT_TRUE(hasMatchesExactlyAttributes(Assume, New, ""));
214         ASSERT_TRUE(hasMatchesExactlyAttributes(Assume, Old,
215                                        "(nonnull|align|dereferenceable)"));
216         Old->replaceAllUsesWith(New);
217         ASSERT_TRUE(hasMatchesExactlyAttributes(Assume, New,
218                                        "(nonnull|align|dereferenceable)"));
219         ASSERT_TRUE(hasMatchesExactlyAttributes(Assume, Old, ""));
220       }));
221   RunTest(Head, Tail, Tests);
222 }
223 
224 static bool FindExactlyAttributes(RetainedKnowledgeMap &Map, Value *WasOn,
225                                  StringRef AttrToMatch) {
226   Regex Reg(AttrToMatch);
227   SmallVector<StringRef, 1> Matches;
228   for (StringRef Attr : {
229 #define GET_ATTR_NAMES
230 #define ATTRIBUTE_ENUM(ENUM_NAME, DISPLAY_NAME) StringRef(#DISPLAY_NAME),
231 #include "llvm/IR/Attributes.inc"
232        }) {
233     bool ShouldHaveAttr = Reg.match(Attr, &Matches) && Matches[0] == Attr;
234 
235     if (ShouldHaveAttr != (Map.find(RetainedKnowledgeKey{WasOn, Attribute::getAttrKindFromName(Attr)}) != Map.end()))
236       return false;
237   }
238   return true;
239 }
240 
241 static bool MapHasRightValue(RetainedKnowledgeMap &Map, AssumeInst *II,
242                              RetainedKnowledgeKey Key, MinMax MM) {
243   auto LookupIt = Map.find(Key);
244   return (LookupIt != Map.end()) && (LookupIt->second[II].Min == MM.Min) &&
245          (LookupIt->second[II].Max == MM.Max);
246 }
247 
248 TEST(AssumeQueryAPI, fillMapFromAssume) {
249   EnableKnowledgeRetention.setValue(true);
250   StringRef Head =
251       "declare void @llvm.assume(i1)\n"
252       "declare void @func(i32*, i32*, i32*)\n"
253       "declare void @func1(i32*, i32*, i32*, i32*)\n"
254       "declare void @func_many(i32*) \"no-jump-tables\" nounwind "
255       "\"less-precise-fpmad\" willreturn norecurse\n"
256       "define void @test(i32* %P, i32* %P1, i32* %P2, i32* %P3) {\n";
257   StringRef Tail = "ret void\n"
258                    "}";
259   std::vector<std::pair<StringRef, llvm::function_ref<void(Instruction *)>>>
260       Tests;
261   Tests.push_back(std::make_pair(
262       "call void @func(i32* nonnull align 4 dereferenceable(16) %P, i32* align "
263       "8 noalias %P1, i32* align 8 dereferenceable(8) %P2)\n",
264       [](Instruction *I) {
265         auto *Assume = buildAssumeFromInst(I);
266         Assume->insertBefore(I);
267 
268         RetainedKnowledgeMap Map;
269         fillMapFromAssume(*Assume, Map);
270         ASSERT_TRUE(FindExactlyAttributes(Map, I->getOperand(0),
271                                        "(nonnull|align|dereferenceable)"));
272         ASSERT_FALSE(FindExactlyAttributes(Map, I->getOperand(1),
273                                        "(align)"));
274         ASSERT_TRUE(FindExactlyAttributes(Map, I->getOperand(2),
275                                        "(align|dereferenceable)"));
276         ASSERT_TRUE(MapHasRightValue(
277             Map, Assume, {I->getOperand(0), Attribute::Dereferenceable}, {16, 16}));
278         ASSERT_TRUE(MapHasRightValue(Map, Assume, {I->getOperand(0), Attribute::Alignment},
279                                {4, 4}));
280         ASSERT_TRUE(MapHasRightValue(Map, Assume, {I->getOperand(0), Attribute::Alignment},
281                                {4, 4}));
282       }));
283   Tests.push_back(std::make_pair(
284       "call void @func1(i32* nonnull align 32 dereferenceable(48) %P, i32* "
285       "nonnull "
286       "align 8 dereferenceable(28) %P, i32* nonnull align 64 "
287       "dereferenceable(4) "
288       "%P, i32* nonnull align 16 dereferenceable(12) %P)\n",
289       [](Instruction *I) {
290         auto *Assume = buildAssumeFromInst(I);
291         Assume->insertBefore(I);
292 
293         RetainedKnowledgeMap Map;
294         fillMapFromAssume(*Assume, Map);
295 
296         ASSERT_TRUE(FindExactlyAttributes(Map, I->getOperand(0),
297                                        "(nonnull|align|dereferenceable)"));
298         ASSERT_TRUE(FindExactlyAttributes(Map, I->getOperand(1),
299                                        "(nonnull|align|dereferenceable)"));
300         ASSERT_TRUE(FindExactlyAttributes(Map, I->getOperand(2),
301                                        "(nonnull|align|dereferenceable)"));
302         ASSERT_TRUE(FindExactlyAttributes(Map, I->getOperand(3),
303                                        "(nonnull|align|dereferenceable)"));
304         ASSERT_TRUE(MapHasRightValue(
305             Map, Assume, {I->getOperand(0), Attribute::Dereferenceable},
306             {48, 48}));
307         ASSERT_TRUE(MapHasRightValue(
308             Map, Assume, {I->getOperand(0), Attribute::Alignment}, {64, 64}));
309       }));
310   Tests.push_back(std::make_pair(
311       "call void @func_many(i32* align 8 %P1) cold\n", [](Instruction *I) {
312         ShouldPreserveAllAttributes.setValue(true);
313         auto *Assume = buildAssumeFromInst(I);
314         Assume->insertBefore(I);
315 
316         RetainedKnowledgeMap Map;
317         fillMapFromAssume(*Assume, Map);
318 
319         ASSERT_TRUE(FindExactlyAttributes(
320             Map, nullptr, "(nounwind|norecurse|willreturn|cold)"));
321         ShouldPreserveAllAttributes.setValue(false);
322       }));
323   Tests.push_back(
324       std::make_pair("call void @llvm.assume(i1 true)\n", [](Instruction *I) {
325         RetainedKnowledgeMap Map;
326         fillMapFromAssume(*cast<AssumeInst>(I), Map);
327 
328         ASSERT_TRUE(FindExactlyAttributes(Map, nullptr, ""));
329         ASSERT_TRUE(Map.empty());
330       }));
331   Tests.push_back(std::make_pair(
332       "call void @func1(i32* readnone align 32 "
333       "dereferenceable(48) noalias %P, i32* "
334       "align 8 dereferenceable(28) %P1, i32* align 64 "
335       "dereferenceable(4) "
336       "%P2, i32* nonnull align 16 dereferenceable(12) %P3)\n",
337       [](Instruction *I) {
338         auto *Assume = buildAssumeFromInst(I);
339         Assume->insertBefore(I);
340 
341         RetainedKnowledgeMap Map;
342         fillMapFromAssume(*Assume, Map);
343 
344         ASSERT_TRUE(FindExactlyAttributes(Map, I->getOperand(0),
345                                     "(align|dereferenceable)"));
346         ASSERT_TRUE(FindExactlyAttributes(Map, I->getOperand(1),
347                                     "(align|dereferenceable)"));
348         ASSERT_TRUE(FindExactlyAttributes(Map, I->getOperand(2),
349                                        "(align|dereferenceable)"));
350         ASSERT_TRUE(FindExactlyAttributes(Map, I->getOperand(3),
351                                        "(nonnull|align|dereferenceable)"));
352         ASSERT_TRUE(MapHasRightValue(Map, Assume, {I->getOperand(0), Attribute::Alignment},
353                                {32, 32}));
354         ASSERT_TRUE(MapHasRightValue(
355             Map, Assume, {I->getOperand(0), Attribute::Dereferenceable}, {48, 48}));
356         ASSERT_TRUE(MapHasRightValue(
357             Map, Assume, {I->getOperand(1), Attribute::Dereferenceable}, {28, 28}));
358         ASSERT_TRUE(MapHasRightValue(Map, Assume, {I->getOperand(1), Attribute::Alignment},
359                                {8, 8}));
360         ASSERT_TRUE(MapHasRightValue(Map, Assume, {I->getOperand(2), Attribute::Alignment},
361                                {64, 64}));
362         ASSERT_TRUE(MapHasRightValue(
363             Map, Assume, {I->getOperand(2), Attribute::Dereferenceable}, {4, 4}));
364         ASSERT_TRUE(MapHasRightValue(Map, Assume, {I->getOperand(3), Attribute::Alignment},
365                                {16, 16}));
366         ASSERT_TRUE(MapHasRightValue(
367             Map, Assume, {I->getOperand(3), Attribute::Dereferenceable}, {12, 12}));
368       }));
369 
370   /// Keep this test last as it modifies the function.
371   Tests.push_back(std::make_pair(
372       "call void @func(i32* nonnull align 4 dereferenceable(16) %P, i32* align "
373       "8 noalias %P1, i32* %P2)\n",
374       [](Instruction *I) {
375         auto *Assume = buildAssumeFromInst(I);
376         Assume->insertBefore(I);
377 
378         RetainedKnowledgeMap Map;
379         fillMapFromAssume(*Assume, Map);
380 
381         Value *New = I->getFunction()->getArg(3);
382         Value *Old = I->getOperand(0);
383         ASSERT_TRUE(FindExactlyAttributes(Map, New, ""));
384         ASSERT_TRUE(FindExactlyAttributes(Map, Old,
385                                        "(nonnull|align|dereferenceable)"));
386         Old->replaceAllUsesWith(New);
387         Map.clear();
388         fillMapFromAssume(*Assume, Map);
389         ASSERT_TRUE(FindExactlyAttributes(Map, New,
390                                        "(nonnull|align|dereferenceable)"));
391         ASSERT_TRUE(FindExactlyAttributes(Map, Old, ""));
392       }));
393   Tests.push_back(std::make_pair(
394       "call void @llvm.assume(i1 true) [\"align\"(i8* undef, i32 undef)]",
395       [](Instruction *I) {
396         // Don't crash but don't learn from undef.
397         RetainedKnowledgeMap Map;
398         fillMapFromAssume(*cast<AssumeInst>(I), Map);
399 
400         ASSERT_TRUE(Map.empty());
401       }));
402   RunTest(Head, Tail, Tests);
403 }
404 
405 static void RunRandTest(uint64_t Seed, int Size, int MinCount, int MaxCount,
406                         unsigned MaxValue) {
407   LLVMContext C;
408   SMDiagnostic Err;
409 
410   std::random_device dev;
411   std::mt19937 Rng(Seed);
412   std::uniform_int_distribution<int> DistCount(MinCount, MaxCount);
413   std::uniform_int_distribution<unsigned> DistValue(0, MaxValue);
414   std::uniform_int_distribution<unsigned> DistAttr(0,
415                                                    Attribute::EndAttrKinds - 1);
416 
417   std::unique_ptr<Module> Mod = std::make_unique<Module>("AssumeQueryAPI", C);
418   if (!Mod)
419     Err.print("AssumeQueryAPI", errs());
420 
421   std::vector<Type *> TypeArgs;
422   for (int i = 0; i < (Size * 2); i++)
423     TypeArgs.push_back(Type::getInt32PtrTy(C));
424   FunctionType *FuncType =
425       FunctionType::get(Type::getVoidTy(C), TypeArgs, false);
426 
427   Function *F =
428       Function::Create(FuncType, GlobalValue::ExternalLinkage, "test", &*Mod);
429   BasicBlock *BB = BasicBlock::Create(C);
430   BB->insertInto(F);
431   Instruction *Ret = ReturnInst::Create(C);
432   Ret->insertInto(BB, BB->begin());
433   Function *FnAssume = Intrinsic::getDeclaration(Mod.get(), Intrinsic::assume);
434 
435   std::vector<Argument *> ShuffledArgs;
436   BitVector HasArg;
437   for (auto &Arg : F->args()) {
438     ShuffledArgs.push_back(&Arg);
439     HasArg.push_back(false);
440   }
441 
442   std::shuffle(ShuffledArgs.begin(), ShuffledArgs.end(), Rng);
443 
444   std::vector<OperandBundleDef> OpBundle;
445   OpBundle.reserve(Size);
446   std::vector<Value *> Args;
447   Args.reserve(2);
448   for (int i = 0; i < Size; i++) {
449     int count = DistCount(Rng);
450     int value = DistValue(Rng);
451     int attr = DistAttr(Rng);
452     std::string str;
453     raw_string_ostream ss(str);
454     ss << Attribute::getNameFromAttrKind(
455         static_cast<Attribute::AttrKind>(attr));
456     Args.clear();
457 
458     if (count > 0) {
459       Args.push_back(ShuffledArgs[i]);
460       HasArg[i] = true;
461     }
462     if (count > 1)
463       Args.push_back(ConstantInt::get(Type::getInt32Ty(C), value));
464 
465     OpBundle.push_back(OperandBundleDef{ss.str().c_str(), std::move(Args)});
466   }
467 
468   auto *Assume = cast<AssumeInst>(CallInst::Create(
469       FnAssume, ArrayRef<Value *>({ConstantInt::getTrue(C)}), OpBundle));
470   Assume->insertBefore(&F->begin()->front());
471   RetainedKnowledgeMap Map;
472   fillMapFromAssume(*Assume, Map);
473   for (int i = 0; i < (Size * 2); i++) {
474     if (!HasArg[i])
475       continue;
476     RetainedKnowledge K =
477         getKnowledgeFromUseInAssume(&*ShuffledArgs[i]->use_begin());
478     auto LookupIt = Map.find(RetainedKnowledgeKey{K.WasOn, K.AttrKind});
479     ASSERT_TRUE(LookupIt != Map.end());
480     MinMax MM = LookupIt->second[Assume];
481     ASSERT_TRUE(MM.Min == MM.Max);
482     ASSERT_TRUE(MM.Min == K.ArgValue);
483   }
484 }
485 
486 TEST(AssumeQueryAPI, getKnowledgeFromUseInAssume) {
487   // // For Fuzzing
488   // std::random_device dev;
489   // std::mt19937 Rng(dev());
490   // while (true) {
491   //   unsigned Seed = Rng();
492   //   dbgs() << Seed << "\n";
493   //   RunRandTest(Seed, 100000, 0, 2, 100);
494   // }
495   RunRandTest(23456, 4, 0, 2, 100);
496   RunRandTest(560987, 25, -3, 2, 100);
497 
498   // Large bundles can lead to special cases. this is why this test is soo
499   // large.
500   RunRandTest(9876789, 100000, -0, 7, 100);
501 }
502 
503 TEST(AssumeQueryAPI, AssumptionCache) {
504   LLVMContext C;
505   SMDiagnostic Err;
506   std::unique_ptr<Module> Mod = parseAssemblyString(
507       "declare void @llvm.assume(i1)\n"
508       "define void @test(i32* %P, i32* %P1, i32* %P2, i32* %P3, i1 %B) {\n"
509       "call void @llvm.assume(i1 true) [\"nonnull\"(i32* %P), \"align\"(i32* "
510       "%P2, i32 4), \"align\"(i32* %P, i32 8)]\n"
511       "call void @llvm.assume(i1 %B) [\"test\"(i32* %P1), "
512       "\"dereferenceable\"(i32* %P, i32 4)]\n"
513       "ret void\n}\n",
514       Err, C);
515   if (!Mod)
516     Err.print("AssumeQueryAPI", errs());
517   Function *F = Mod->getFunction("test");
518   BasicBlock::iterator First = F->begin()->begin();
519   BasicBlock::iterator Second = F->begin()->begin();
520   Second++;
521   AssumptionCache AC(*F);
522   auto AR = AC.assumptionsFor(F->getArg(3));
523   ASSERT_EQ(AR.size(), 0u);
524   AR = AC.assumptionsFor(F->getArg(1));
525   ASSERT_EQ(AR.size(), 1u);
526   ASSERT_EQ(AR[0].Index, 0u);
527   ASSERT_EQ(AR[0].Assume, &*Second);
528   AR = AC.assumptionsFor(F->getArg(2));
529   ASSERT_EQ(AR.size(), 1u);
530   ASSERT_EQ(AR[0].Index, 1u);
531   ASSERT_EQ(AR[0].Assume, &*First);
532   AR = AC.assumptionsFor(F->getArg(0));
533   ASSERT_EQ(AR.size(), 3u);
534   llvm::sort(AR,
535              [](const auto &L, const auto &R) { return L.Index < R.Index; });
536   ASSERT_EQ(AR[0].Assume, &*First);
537   ASSERT_EQ(AR[0].Index, 0u);
538   ASSERT_EQ(AR[1].Assume, &*Second);
539   ASSERT_EQ(AR[1].Index, 1u);
540   ASSERT_EQ(AR[2].Assume, &*First);
541   ASSERT_EQ(AR[2].Index, 2u);
542   AR = AC.assumptionsFor(F->getArg(4));
543   ASSERT_EQ(AR.size(), 1u);
544   ASSERT_EQ(AR[0].Assume, &*Second);
545   ASSERT_EQ(AR[0].Index, AssumptionCache::ExprResultIdx);
546   AC.unregisterAssumption(cast<AssumeInst>(&*Second));
547   AR = AC.assumptionsFor(F->getArg(1));
548   ASSERT_EQ(AR.size(), 0u);
549   AR = AC.assumptionsFor(F->getArg(0));
550   ASSERT_EQ(AR.size(), 3u);
551   llvm::sort(AR,
552              [](const auto &L, const auto &R) { return L.Index < R.Index; });
553   ASSERT_EQ(AR[0].Assume, &*First);
554   ASSERT_EQ(AR[0].Index, 0u);
555   ASSERT_EQ(AR[1].Assume, nullptr);
556   ASSERT_EQ(AR[1].Index, 1u);
557   ASSERT_EQ(AR[2].Assume, &*First);
558   ASSERT_EQ(AR[2].Index, 2u);
559   AR = AC.assumptionsFor(F->getArg(2));
560   ASSERT_EQ(AR.size(), 1u);
561   ASSERT_EQ(AR[0].Index, 1u);
562   ASSERT_EQ(AR[0].Assume, &*First);
563 }
564 
565 TEST(AssumeQueryAPI, Alignment) {
566   LLVMContext C;
567   SMDiagnostic Err;
568   std::unique_ptr<Module> Mod = parseAssemblyString(
569       "declare void @llvm.assume(i1)\n"
570       "define void @test(i32* %P, i32* %P1, i32* %P2, i32 %I3, i1 %B) {\n"
571       "call void @llvm.assume(i1 true) [\"align\"(i32* %P, i32 8, i32 %I3)]\n"
572       "call void @llvm.assume(i1 true) [\"align\"(i32* %P1, i32 %I3, i32 "
573       "%I3)]\n"
574       "call void @llvm.assume(i1 true) [\"align\"(i32* %P2, i32 16, i32 8)]\n"
575       "ret void\n}\n",
576       Err, C);
577   if (!Mod)
578     Err.print("AssumeQueryAPI", errs());
579 
580   Function *F = Mod->getFunction("test");
581   BasicBlock::iterator Start = F->begin()->begin();
582   AssumeInst *II;
583   RetainedKnowledge RK;
584   II = cast<AssumeInst>(&*Start);
585   RK = getKnowledgeFromBundle(*II, II->bundle_op_info_begin()[0]);
586   ASSERT_EQ(RK.AttrKind, Attribute::Alignment);
587   ASSERT_EQ(RK.WasOn, F->getArg(0));
588   ASSERT_EQ(RK.ArgValue, 1u);
589   Start++;
590   II = cast<AssumeInst>(&*Start);
591   RK = getKnowledgeFromBundle(*II, II->bundle_op_info_begin()[0]);
592   ASSERT_EQ(RK.AttrKind, Attribute::Alignment);
593   ASSERT_EQ(RK.WasOn, F->getArg(1));
594   ASSERT_EQ(RK.ArgValue, 1u);
595   Start++;
596   II = cast<AssumeInst>(&*Start);
597   RK = getKnowledgeFromBundle(*II, II->bundle_op_info_begin()[0]);
598   ASSERT_EQ(RK.AttrKind, Attribute::Alignment);
599   ASSERT_EQ(RK.WasOn, F->getArg(2));
600   ASSERT_EQ(RK.ArgValue, 8u);
601 }
602