xref: /llvm-project/llvm/unittests/Analysis/AssumeBundleQueriesTest.cpp (revision 52b48a70d3752f9db36ddcfd26d0451c009b19fc)
1 //===- AssumeBundleQueriesTest.cpp ------------------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "llvm/Analysis/AssumeBundleQueries.h"
10 #include "llvm/Analysis/AssumptionCache.h"
11 #include "llvm/AsmParser/Parser.h"
12 #include "llvm/IR/IntrinsicInst.h"
13 #include "llvm/IR/LLVMContext.h"
14 #include "llvm/IR/Module.h"
15 #include "llvm/Support/CommandLine.h"
16 #include "llvm/Support/Regex.h"
17 #include "llvm/Support/SourceMgr.h"
18 #include "llvm/Transforms/Utils/AssumeBundleBuilder.h"
19 #include "gtest/gtest.h"
20 #include <random>
21 
22 using namespace llvm;
23 
24 namespace llvm {
25 extern cl::opt<bool> ShouldPreserveAllAttributes;
26 } // namespace llvm
27 
28 static void RunTest(
29     StringRef Head, StringRef Tail,
30     std::vector<std::pair<StringRef, llvm::function_ref<void(Instruction *)>>>
31         &Tests) {
32   for (auto &Elem : Tests) {
33     std::string IR;
34     IR.append(Head.begin(), Head.end());
35     IR.append(Elem.first.begin(), Elem.first.end());
36     IR.append(Tail.begin(), Tail.end());
37     LLVMContext C;
38     SMDiagnostic Err;
39     std::unique_ptr<Module> Mod = parseAssemblyString(IR, Err, C);
40     if (!Mod)
41       Err.print("AssumeQueryAPI", errs());
42     Elem.second(&*(Mod->getFunction("test")->begin()->begin()));
43   }
44 }
45 
46 bool hasMatchesExactlyAttributes(AssumeInst *Assume, Value *WasOn,
47                                  StringRef AttrToMatch) {
48   Regex Reg(AttrToMatch);
49   SmallVector<StringRef, 1> Matches;
50   for (StringRef Attr : {
51 #define GET_ATTR_NAMES
52 #define ATTRIBUTE_ALL(ENUM_NAME, DISPLAY_NAME) StringRef(#DISPLAY_NAME),
53 #include "llvm/IR/Attributes.inc"
54        }) {
55     bool ShouldHaveAttr = Reg.match(Attr, &Matches) && Matches[0] == Attr;
56     if (ShouldHaveAttr != hasAttributeInAssume(*Assume, WasOn, Attr))
57       return false;
58   }
59   return true;
60 }
61 
62 bool hasTheRightValue(AssumeInst *Assume, Value *WasOn,
63                       Attribute::AttrKind Kind, unsigned Value) {
64   uint64_t ArgVal = 0;
65   if (!hasAttributeInAssume(*Assume, WasOn, Kind, &ArgVal))
66     return false;
67   if (ArgVal != Value)
68     return false;
69   return true;
70 }
71 
72 TEST(AssumeQueryAPI, hasAttributeInAssume) {
73   EnableKnowledgeRetention.setValue(true);
74   StringRef Head =
75       "declare void @llvm.assume(i1)\n"
76       "declare void @func(i32*, i32*, i32*)\n"
77       "declare void @func1(i32*, i32*, i32*, i32*)\n"
78       "declare void @func_many(i32*) \"no-jump-tables\" nounwind "
79       "\"less-precise-fpmad\" willreturn norecurse\n"
80       "define void @test(i32* %P, i32* %P1, i32* %P2, i32* %P3) {\n";
81   StringRef Tail = "ret void\n"
82                    "}";
83   std::vector<std::pair<StringRef, llvm::function_ref<void(Instruction *)>>>
84       Tests;
85   Tests.push_back(std::make_pair(
86       "call void @func(i32* nonnull align 4 dereferenceable(16) %P, i32* align "
87       "8 noalias %P1, i32* align 8 noundef %P2)\n",
88       [](Instruction *I) {
89         auto *Assume = buildAssumeFromInst(I);
90         Assume->insertBefore(I);
91         ASSERT_TRUE(hasMatchesExactlyAttributes(Assume, I->getOperand(0),
92                                        "(nonnull|align|dereferenceable)"));
93         ASSERT_TRUE(hasMatchesExactlyAttributes(Assume, I->getOperand(1),
94                                        "()"));
95         ASSERT_TRUE(hasMatchesExactlyAttributes(Assume, I->getOperand(2),
96                                        "(align|noundef)"));
97         ASSERT_TRUE(hasTheRightValue(Assume, I->getOperand(0),
98                                      Attribute::AttrKind::Dereferenceable, 16));
99         ASSERT_TRUE(hasTheRightValue(Assume, I->getOperand(0),
100                                      Attribute::AttrKind::Alignment, 4));
101         ASSERT_TRUE(hasTheRightValue(Assume, I->getOperand(0),
102                                      Attribute::AttrKind::Alignment, 4));
103       }));
104   Tests.push_back(std::make_pair(
105       "call void @func1(i32* nonnull align 32 dereferenceable(48) %P, i32* "
106       "nonnull "
107       "align 8 dereferenceable(28) %P, i32* nonnull align 64 "
108       "dereferenceable(4) "
109       "%P, i32* nonnull align 16 dereferenceable(12) %P)\n",
110       [](Instruction *I) {
111         auto *Assume = buildAssumeFromInst(I);
112         Assume->insertBefore(I);
113         ASSERT_TRUE(hasMatchesExactlyAttributes(Assume, I->getOperand(0),
114                                        "(nonnull|align|dereferenceable)"));
115         ASSERT_TRUE(hasMatchesExactlyAttributes(Assume, I->getOperand(1),
116                                        "(nonnull|align|dereferenceable)"));
117         ASSERT_TRUE(hasMatchesExactlyAttributes(Assume, I->getOperand(2),
118                                        "(nonnull|align|dereferenceable)"));
119         ASSERT_TRUE(hasMatchesExactlyAttributes(Assume, I->getOperand(3),
120                                        "(nonnull|align|dereferenceable)"));
121         ASSERT_TRUE(hasTheRightValue(Assume, I->getOperand(0),
122                                      Attribute::AttrKind::Dereferenceable, 48));
123         ASSERT_TRUE(hasTheRightValue(Assume, I->getOperand(0),
124                                      Attribute::AttrKind::Alignment, 64));
125         ASSERT_TRUE(hasTheRightValue(Assume, I->getOperand(1),
126                                      Attribute::AttrKind::Alignment, 64));
127       }));
128   Tests.push_back(std::make_pair(
129       "call void @func_many(i32* align 8 noundef %P1) cold\n", [](Instruction *I) {
130         ShouldPreserveAllAttributes.setValue(true);
131         auto *Assume = buildAssumeFromInst(I);
132         Assume->insertBefore(I);
133         ASSERT_TRUE(hasMatchesExactlyAttributes(
134             Assume, nullptr,
135             "(align|nounwind|norecurse|noundef|willreturn|cold)"));
136         ShouldPreserveAllAttributes.setValue(false);
137       }));
138   Tests.push_back(
139       std::make_pair("call void @llvm.assume(i1 true)\n", [](Instruction *I) {
140         auto *Assume = cast<AssumeInst>(I);
141         ASSERT_TRUE(hasMatchesExactlyAttributes(Assume, nullptr, ""));
142       }));
143   Tests.push_back(std::make_pair(
144       "call void @func1(i32* readnone align 32 "
145       "dereferenceable(48) noalias %P, i32* "
146       "align 8 dereferenceable(28) %P1, i32* align 64 "
147       "dereferenceable(4) "
148       "%P2, i32* nonnull align 16 dereferenceable(12) %P3)\n",
149       [](Instruction *I) {
150         auto *Assume = buildAssumeFromInst(I);
151         Assume->insertBefore(I);
152         ASSERT_TRUE(hasMatchesExactlyAttributes(
153             Assume, I->getOperand(0),
154             "(align|dereferenceable)"));
155         ASSERT_TRUE(hasMatchesExactlyAttributes(Assume, I->getOperand(1),
156                                        "(align|dereferenceable)"));
157         ASSERT_TRUE(hasMatchesExactlyAttributes(Assume, I->getOperand(2),
158                                        "(align|dereferenceable)"));
159         ASSERT_TRUE(hasMatchesExactlyAttributes(Assume, I->getOperand(3),
160                                        "(nonnull|align|dereferenceable)"));
161         ASSERT_TRUE(hasTheRightValue(Assume, I->getOperand(0),
162                                      Attribute::AttrKind::Alignment, 32));
163         ASSERT_TRUE(hasTheRightValue(Assume, I->getOperand(0),
164                                      Attribute::AttrKind::Dereferenceable, 48));
165         ASSERT_TRUE(hasTheRightValue(Assume, I->getOperand(1),
166                                      Attribute::AttrKind::Dereferenceable, 28));
167         ASSERT_TRUE(hasTheRightValue(Assume, I->getOperand(1),
168                                      Attribute::AttrKind::Alignment, 8));
169         ASSERT_TRUE(hasTheRightValue(Assume, I->getOperand(2),
170                                      Attribute::AttrKind::Alignment, 64));
171         ASSERT_TRUE(hasTheRightValue(Assume, I->getOperand(2),
172                                      Attribute::AttrKind::Dereferenceable, 4));
173         ASSERT_TRUE(hasTheRightValue(Assume, I->getOperand(3),
174                                      Attribute::AttrKind::Alignment, 16));
175         ASSERT_TRUE(hasTheRightValue(Assume, I->getOperand(3),
176                                      Attribute::AttrKind::Dereferenceable, 12));
177       }));
178 
179   Tests.push_back(std::make_pair(
180       "call void @func1(i32* readnone align 32 "
181       "dereferenceable(48) noalias %P, i32* "
182       "align 8 dereferenceable(28) %P1, i32* align 64 "
183       "dereferenceable(4) "
184       "%P2, i32* nonnull align 16 dereferenceable(12) %P3)\n",
185       [](Instruction *I) {
186         auto *Assume = buildAssumeFromInst(I);
187         Assume->insertBefore(I);
188         I->getOperand(1)->dropDroppableUses();
189         I->getOperand(2)->dropDroppableUses();
190         I->getOperand(3)->dropDroppableUses();
191         ASSERT_TRUE(hasMatchesExactlyAttributes(
192             Assume, I->getOperand(0),
193             "(align|dereferenceable)"));
194         ASSERT_TRUE(hasMatchesExactlyAttributes(Assume, I->getOperand(1),
195                                        ""));
196         ASSERT_TRUE(hasMatchesExactlyAttributes(Assume, I->getOperand(2),
197                                        ""));
198         ASSERT_TRUE(hasMatchesExactlyAttributes(Assume, I->getOperand(3),
199                                        ""));
200         ASSERT_TRUE(hasTheRightValue(Assume, I->getOperand(0),
201                                      Attribute::AttrKind::Alignment, 32));
202         ASSERT_TRUE(hasTheRightValue(Assume, I->getOperand(0),
203                                      Attribute::AttrKind::Dereferenceable, 48));
204       }));
205   Tests.push_back(std::make_pair(
206       "call void @func(i32* nonnull align 4 dereferenceable(16) %P, i32* align "
207       "8 noalias %P1, i32* %P1)\n",
208       [](Instruction *I) {
209         auto *Assume = buildAssumeFromInst(I);
210         Assume->insertBefore(I);
211         Value *New = I->getFunction()->getArg(3);
212         Value *Old = I->getOperand(0);
213         ASSERT_TRUE(hasMatchesExactlyAttributes(Assume, New, ""));
214         ASSERT_TRUE(hasMatchesExactlyAttributes(Assume, Old,
215                                        "(nonnull|align|dereferenceable)"));
216         Old->replaceAllUsesWith(New);
217         ASSERT_TRUE(hasMatchesExactlyAttributes(Assume, New,
218                                        "(nonnull|align|dereferenceable)"));
219         ASSERT_TRUE(hasMatchesExactlyAttributes(Assume, Old, ""));
220       }));
221   RunTest(Head, Tail, Tests);
222 }
223 
224 static bool FindExactlyAttributes(RetainedKnowledgeMap &Map, Value *WasOn,
225                                  StringRef AttrToMatch) {
226   Regex Reg(AttrToMatch);
227   SmallVector<StringRef, 1> Matches;
228   for (StringRef Attr : {
229 #define GET_ATTR_NAMES
230 #define ATTRIBUTE_ENUM(ENUM_NAME, DISPLAY_NAME) StringRef(#DISPLAY_NAME),
231 #include "llvm/IR/Attributes.inc"
232        }) {
233     bool ShouldHaveAttr = Reg.match(Attr, &Matches) && Matches[0] == Attr;
234 
235     if (ShouldHaveAttr != (Map.contains(RetainedKnowledgeKey{
236                               WasOn, Attribute::getAttrKindFromName(Attr)})))
237       return false;
238   }
239   return true;
240 }
241 
242 static bool MapHasRightValue(RetainedKnowledgeMap &Map, AssumeInst *II,
243                              RetainedKnowledgeKey Key, MinMax MM) {
244   auto LookupIt = Map.find(Key);
245   return (LookupIt != Map.end()) && (LookupIt->second[II].Min == MM.Min) &&
246          (LookupIt->second[II].Max == MM.Max);
247 }
248 
249 TEST(AssumeQueryAPI, fillMapFromAssume) {
250   EnableKnowledgeRetention.setValue(true);
251   StringRef Head =
252       "declare void @llvm.assume(i1)\n"
253       "declare void @func(i32*, i32*, i32*)\n"
254       "declare void @func1(i32*, i32*, i32*, i32*)\n"
255       "declare void @func_many(i32*) \"no-jump-tables\" nounwind "
256       "\"less-precise-fpmad\" willreturn norecurse\n"
257       "define void @test(i32* %P, i32* %P1, i32* %P2, i32* %P3) {\n";
258   StringRef Tail = "ret void\n"
259                    "}";
260   std::vector<std::pair<StringRef, llvm::function_ref<void(Instruction *)>>>
261       Tests;
262   Tests.push_back(std::make_pair(
263       "call void @func(i32* nonnull align 4 dereferenceable(16) %P, i32* align "
264       "8 noalias %P1, i32* align 8 dereferenceable(8) %P2)\n",
265       [](Instruction *I) {
266         auto *Assume = buildAssumeFromInst(I);
267         Assume->insertBefore(I);
268 
269         RetainedKnowledgeMap Map;
270         fillMapFromAssume(*Assume, Map);
271         ASSERT_TRUE(FindExactlyAttributes(Map, I->getOperand(0),
272                                        "(nonnull|align|dereferenceable)"));
273         ASSERT_FALSE(FindExactlyAttributes(Map, I->getOperand(1),
274                                        "(align)"));
275         ASSERT_TRUE(FindExactlyAttributes(Map, I->getOperand(2),
276                                        "(align|dereferenceable)"));
277         ASSERT_TRUE(MapHasRightValue(
278             Map, Assume, {I->getOperand(0), Attribute::Dereferenceable}, {16, 16}));
279         ASSERT_TRUE(MapHasRightValue(Map, Assume, {I->getOperand(0), Attribute::Alignment},
280                                {4, 4}));
281         ASSERT_TRUE(MapHasRightValue(Map, Assume, {I->getOperand(0), Attribute::Alignment},
282                                {4, 4}));
283       }));
284   Tests.push_back(std::make_pair(
285       "call void @func1(i32* nonnull align 32 dereferenceable(48) %P, i32* "
286       "nonnull "
287       "align 8 dereferenceable(28) %P, i32* nonnull align 64 "
288       "dereferenceable(4) "
289       "%P, i32* nonnull align 16 dereferenceable(12) %P)\n",
290       [](Instruction *I) {
291         auto *Assume = buildAssumeFromInst(I);
292         Assume->insertBefore(I);
293 
294         RetainedKnowledgeMap Map;
295         fillMapFromAssume(*Assume, Map);
296 
297         ASSERT_TRUE(FindExactlyAttributes(Map, I->getOperand(0),
298                                        "(nonnull|align|dereferenceable)"));
299         ASSERT_TRUE(FindExactlyAttributes(Map, I->getOperand(1),
300                                        "(nonnull|align|dereferenceable)"));
301         ASSERT_TRUE(FindExactlyAttributes(Map, I->getOperand(2),
302                                        "(nonnull|align|dereferenceable)"));
303         ASSERT_TRUE(FindExactlyAttributes(Map, I->getOperand(3),
304                                        "(nonnull|align|dereferenceable)"));
305         ASSERT_TRUE(MapHasRightValue(
306             Map, Assume, {I->getOperand(0), Attribute::Dereferenceable},
307             {48, 48}));
308         ASSERT_TRUE(MapHasRightValue(
309             Map, Assume, {I->getOperand(0), Attribute::Alignment}, {64, 64}));
310       }));
311   Tests.push_back(std::make_pair(
312       "call void @func_many(i32* align 8 %P1) cold\n", [](Instruction *I) {
313         ShouldPreserveAllAttributes.setValue(true);
314         auto *Assume = buildAssumeFromInst(I);
315         Assume->insertBefore(I);
316 
317         RetainedKnowledgeMap Map;
318         fillMapFromAssume(*Assume, Map);
319 
320         ASSERT_TRUE(FindExactlyAttributes(
321             Map, nullptr, "(nounwind|norecurse|willreturn|cold)"));
322         ShouldPreserveAllAttributes.setValue(false);
323       }));
324   Tests.push_back(
325       std::make_pair("call void @llvm.assume(i1 true)\n", [](Instruction *I) {
326         RetainedKnowledgeMap Map;
327         fillMapFromAssume(*cast<AssumeInst>(I), Map);
328 
329         ASSERT_TRUE(FindExactlyAttributes(Map, nullptr, ""));
330         ASSERT_TRUE(Map.empty());
331       }));
332   Tests.push_back(std::make_pair(
333       "call void @func1(i32* readnone align 32 "
334       "dereferenceable(48) noalias %P, i32* "
335       "align 8 dereferenceable(28) %P1, i32* align 64 "
336       "dereferenceable(4) "
337       "%P2, i32* nonnull align 16 dereferenceable(12) %P3)\n",
338       [](Instruction *I) {
339         auto *Assume = buildAssumeFromInst(I);
340         Assume->insertBefore(I);
341 
342         RetainedKnowledgeMap Map;
343         fillMapFromAssume(*Assume, Map);
344 
345         ASSERT_TRUE(FindExactlyAttributes(Map, I->getOperand(0),
346                                     "(align|dereferenceable)"));
347         ASSERT_TRUE(FindExactlyAttributes(Map, I->getOperand(1),
348                                     "(align|dereferenceable)"));
349         ASSERT_TRUE(FindExactlyAttributes(Map, I->getOperand(2),
350                                        "(align|dereferenceable)"));
351         ASSERT_TRUE(FindExactlyAttributes(Map, I->getOperand(3),
352                                        "(nonnull|align|dereferenceable)"));
353         ASSERT_TRUE(MapHasRightValue(Map, Assume, {I->getOperand(0), Attribute::Alignment},
354                                {32, 32}));
355         ASSERT_TRUE(MapHasRightValue(
356             Map, Assume, {I->getOperand(0), Attribute::Dereferenceable}, {48, 48}));
357         ASSERT_TRUE(MapHasRightValue(
358             Map, Assume, {I->getOperand(1), Attribute::Dereferenceable}, {28, 28}));
359         ASSERT_TRUE(MapHasRightValue(Map, Assume, {I->getOperand(1), Attribute::Alignment},
360                                {8, 8}));
361         ASSERT_TRUE(MapHasRightValue(Map, Assume, {I->getOperand(2), Attribute::Alignment},
362                                {64, 64}));
363         ASSERT_TRUE(MapHasRightValue(
364             Map, Assume, {I->getOperand(2), Attribute::Dereferenceable}, {4, 4}));
365         ASSERT_TRUE(MapHasRightValue(Map, Assume, {I->getOperand(3), Attribute::Alignment},
366                                {16, 16}));
367         ASSERT_TRUE(MapHasRightValue(
368             Map, Assume, {I->getOperand(3), Attribute::Dereferenceable}, {12, 12}));
369       }));
370 
371   /// Keep this test last as it modifies the function.
372   Tests.push_back(std::make_pair(
373       "call void @func(i32* nonnull align 4 dereferenceable(16) %P, i32* align "
374       "8 noalias %P1, i32* %P2)\n",
375       [](Instruction *I) {
376         auto *Assume = buildAssumeFromInst(I);
377         Assume->insertBefore(I);
378 
379         RetainedKnowledgeMap Map;
380         fillMapFromAssume(*Assume, Map);
381 
382         Value *New = I->getFunction()->getArg(3);
383         Value *Old = I->getOperand(0);
384         ASSERT_TRUE(FindExactlyAttributes(Map, New, ""));
385         ASSERT_TRUE(FindExactlyAttributes(Map, Old,
386                                        "(nonnull|align|dereferenceable)"));
387         Old->replaceAllUsesWith(New);
388         Map.clear();
389         fillMapFromAssume(*Assume, Map);
390         ASSERT_TRUE(FindExactlyAttributes(Map, New,
391                                        "(nonnull|align|dereferenceable)"));
392         ASSERT_TRUE(FindExactlyAttributes(Map, Old, ""));
393       }));
394   Tests.push_back(std::make_pair(
395       "call void @llvm.assume(i1 true) [\"align\"(i8* undef, i32 undef)]",
396       [](Instruction *I) {
397         // Don't crash but don't learn from undef.
398         RetainedKnowledgeMap Map;
399         fillMapFromAssume(*cast<AssumeInst>(I), Map);
400 
401         ASSERT_TRUE(Map.empty());
402       }));
403   RunTest(Head, Tail, Tests);
404 }
405 
406 static void RunRandTest(uint64_t Seed, int Size, int MinCount, int MaxCount,
407                         unsigned MaxValue) {
408   LLVMContext C;
409   SMDiagnostic Err;
410 
411   std::mt19937 Rng(Seed);
412   std::uniform_int_distribution<int> DistCount(MinCount, MaxCount);
413   std::uniform_int_distribution<unsigned> DistValue(0, MaxValue);
414   std::uniform_int_distribution<unsigned> DistAttr(0,
415                                                    Attribute::EndAttrKinds - 1);
416 
417   std::unique_ptr<Module> Mod = std::make_unique<Module>("AssumeQueryAPI", C);
418   if (!Mod)
419     Err.print("AssumeQueryAPI", errs());
420 
421   std::vector<Type *> TypeArgs;
422   for (int i = 0; i < (Size * 2); i++)
423     TypeArgs.push_back(PointerType::getUnqual(C));
424   FunctionType *FuncType =
425       FunctionType::get(Type::getVoidTy(C), TypeArgs, false);
426 
427   Function *F =
428       Function::Create(FuncType, GlobalValue::ExternalLinkage, "test", &*Mod);
429   BasicBlock *BB = BasicBlock::Create(C);
430   BB->insertInto(F);
431   Instruction *Ret = ReturnInst::Create(C);
432   Ret->insertInto(BB, BB->begin());
433   Function *FnAssume = Intrinsic::getDeclaration(Mod.get(), Intrinsic::assume);
434 
435   std::vector<Argument *> ShuffledArgs;
436   BitVector HasArg;
437   for (auto &Arg : F->args()) {
438     ShuffledArgs.push_back(&Arg);
439     HasArg.push_back(false);
440   }
441 
442   std::shuffle(ShuffledArgs.begin(), ShuffledArgs.end(), Rng);
443 
444   std::vector<OperandBundleDef> OpBundle;
445   OpBundle.reserve(Size);
446   std::vector<Value *> Args;
447   Args.reserve(2);
448   for (int i = 0; i < Size; i++) {
449     int count = DistCount(Rng);
450     int value = DistValue(Rng);
451     int attr = DistAttr(Rng);
452     std::string str;
453     raw_string_ostream ss(str);
454     ss << Attribute::getNameFromAttrKind(
455         static_cast<Attribute::AttrKind>(attr));
456     Args.clear();
457 
458     if (count > 0) {
459       Args.push_back(ShuffledArgs[i]);
460       HasArg[i] = true;
461     }
462     if (count > 1)
463       Args.push_back(ConstantInt::get(Type::getInt32Ty(C), value));
464 
465     OpBundle.push_back(OperandBundleDef{str.c_str(), std::move(Args)});
466   }
467 
468   auto *Assume = cast<AssumeInst>(CallInst::Create(
469       FnAssume, ArrayRef<Value *>({ConstantInt::getTrue(C)}), OpBundle));
470   Assume->insertBefore(&F->begin()->front());
471   RetainedKnowledgeMap Map;
472   fillMapFromAssume(*Assume, Map);
473   for (int i = 0; i < (Size * 2); i++) {
474     if (!HasArg[i])
475       continue;
476     RetainedKnowledge K =
477         getKnowledgeFromUseInAssume(&*ShuffledArgs[i]->use_begin());
478     auto LookupIt = Map.find(RetainedKnowledgeKey{K.WasOn, K.AttrKind});
479     ASSERT_TRUE(LookupIt != Map.end());
480     MinMax MM = LookupIt->second[Assume];
481     ASSERT_TRUE(MM.Min == MM.Max);
482     ASSERT_TRUE(MM.Min == K.ArgValue);
483   }
484 }
485 
486 TEST(AssumeQueryAPI, getKnowledgeFromUseInAssume) {
487   // // For Fuzzing
488   // std::random_device dev;
489   // std::mt19937 Rng(dev());
490   // while (true) {
491   //   unsigned Seed = Rng();
492   //   dbgs() << Seed << "\n";
493   //   RunRandTest(Seed, 100000, 0, 2, 100);
494   // }
495   RunRandTest(23456, 4, 0, 2, 100);
496   RunRandTest(560987, 25, -3, 2, 100);
497 
498   // Large bundles can lead to special cases. this is why this test is soo
499   // large.
500   RunRandTest(9876789, 100000, -0, 7, 100);
501 }
502 
503 TEST(AssumeQueryAPI, AssumptionCache) {
504   LLVMContext C;
505   SMDiagnostic Err;
506   std::unique_ptr<Module> Mod = parseAssemblyString(
507       "declare void @llvm.assume(i1)\n"
508       "define void @test(i32* %P, i32* %P1, i32* %P2, i32* %P3, i1 %B) {\n"
509       "call void @llvm.assume(i1 true) [\"nonnull\"(i32* %P), \"align\"(i32* "
510       "%P2, i32 4), \"align\"(i32* %P, i32 8)]\n"
511       "call void @llvm.assume(i1 %B) [\"test\"(i32* %P1), "
512       "\"dereferenceable\"(i32* %P, i32 4)]\n"
513       "ret void\n}\n",
514       Err, C);
515   if (!Mod)
516     Err.print("AssumeQueryAPI", errs());
517   Function *F = Mod->getFunction("test");
518   BasicBlock::iterator First = F->begin()->begin();
519   BasicBlock::iterator Second = F->begin()->begin();
520   Second++;
521   AssumptionCache AC(*F);
522   auto AR = AC.assumptionsFor(F->getArg(3));
523   ASSERT_EQ(AR.size(), 0u);
524   AR = AC.assumptionsFor(F->getArg(1));
525   ASSERT_EQ(AR.size(), 1u);
526   ASSERT_EQ(AR[0].Index, 0u);
527   ASSERT_EQ(AR[0].Assume, &*Second);
528   AR = AC.assumptionsFor(F->getArg(2));
529   ASSERT_EQ(AR.size(), 1u);
530   ASSERT_EQ(AR[0].Index, 1u);
531   ASSERT_EQ(AR[0].Assume, &*First);
532   AR = AC.assumptionsFor(F->getArg(0));
533   ASSERT_EQ(AR.size(), 3u);
534   llvm::sort(AR,
535              [](const auto &L, const auto &R) { return L.Index < R.Index; });
536   ASSERT_EQ(AR[0].Assume, &*First);
537   ASSERT_EQ(AR[0].Index, 0u);
538   ASSERT_EQ(AR[1].Assume, &*Second);
539   ASSERT_EQ(AR[1].Index, 1u);
540   ASSERT_EQ(AR[2].Assume, &*First);
541   ASSERT_EQ(AR[2].Index, 2u);
542   AR = AC.assumptionsFor(F->getArg(4));
543   ASSERT_EQ(AR.size(), 1u);
544   ASSERT_EQ(AR[0].Assume, &*Second);
545   ASSERT_EQ(AR[0].Index, AssumptionCache::ExprResultIdx);
546   AC.unregisterAssumption(cast<AssumeInst>(&*Second));
547   AR = AC.assumptionsFor(F->getArg(1));
548   ASSERT_EQ(AR.size(), 0u);
549   AR = AC.assumptionsFor(F->getArg(0));
550   ASSERT_EQ(AR.size(), 3u);
551   llvm::sort(AR,
552              [](const auto &L, const auto &R) { return L.Index < R.Index; });
553   ASSERT_EQ(AR[0].Assume, &*First);
554   ASSERT_EQ(AR[0].Index, 0u);
555   ASSERT_EQ(AR[1].Assume, nullptr);
556   ASSERT_EQ(AR[1].Index, 1u);
557   ASSERT_EQ(AR[2].Assume, &*First);
558   ASSERT_EQ(AR[2].Index, 2u);
559   AR = AC.assumptionsFor(F->getArg(2));
560   ASSERT_EQ(AR.size(), 1u);
561   ASSERT_EQ(AR[0].Index, 1u);
562   ASSERT_EQ(AR[0].Assume, &*First);
563 }
564 
565 TEST(AssumeQueryAPI, Alignment) {
566   LLVMContext C;
567   SMDiagnostic Err;
568   std::unique_ptr<Module> Mod = parseAssemblyString(
569       "declare void @llvm.assume(i1)\n"
570       "define void @test(i32* %P, i32* %P1, i32* %P2, i32 %I3, i1 %B) {\n"
571       "call void @llvm.assume(i1 true) [\"align\"(i32* %P, i32 8, i32 %I3)]\n"
572       "call void @llvm.assume(i1 true) [\"align\"(i32* %P1, i32 %I3, i32 "
573       "%I3)]\n"
574       "call void @llvm.assume(i1 true) [\"align\"(i32* %P2, i32 16, i32 8)]\n"
575       "ret void\n}\n",
576       Err, C);
577   if (!Mod)
578     Err.print("AssumeQueryAPI", errs());
579 
580   Function *F = Mod->getFunction("test");
581   BasicBlock::iterator Start = F->begin()->begin();
582   AssumeInst *II;
583   RetainedKnowledge RK;
584   II = cast<AssumeInst>(&*Start);
585   RK = getKnowledgeFromBundle(*II, II->bundle_op_info_begin()[0]);
586   ASSERT_EQ(RK.AttrKind, Attribute::Alignment);
587   ASSERT_EQ(RK.WasOn, F->getArg(0));
588   ASSERT_EQ(RK.ArgValue, 1u);
589   Start++;
590   II = cast<AssumeInst>(&*Start);
591   RK = getKnowledgeFromBundle(*II, II->bundle_op_info_begin()[0]);
592   ASSERT_EQ(RK.AttrKind, Attribute::Alignment);
593   ASSERT_EQ(RK.WasOn, F->getArg(1));
594   ASSERT_EQ(RK.ArgValue, 1u);
595   Start++;
596   II = cast<AssumeInst>(&*Start);
597   RK = getKnowledgeFromBundle(*II, II->bundle_op_info_begin()[0]);
598   ASSERT_EQ(RK.AttrKind, Attribute::Alignment);
599   ASSERT_EQ(RK.WasOn, F->getArg(2));
600   ASSERT_EQ(RK.ArgValue, 8u);
601 }
602