xref: /llvm-project/llvm/unittests/Analysis/AssumeBundleQueriesTest.cpp (revision be252b4e28ad1d964500079114ea0e0a56bb0a19)
1 //===- AssumeBundleQueriesTest.cpp ------------------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "llvm/Analysis/AssumptionCache.h"
10 #include "llvm/Analysis/AssumeBundleQueries.h"
11 #include "llvm/AsmParser/Parser.h"
12 #include "llvm/IR/LLVMContext.h"
13 #include "llvm/IR/IntrinsicInst.h"
14 #include "llvm/Support/Regex.h"
15 #include "llvm/Support/SourceMgr.h"
16 #include "llvm/Support/CommandLine.h"
17 #include "llvm/Transforms/Utils/AssumeBundleBuilder.h"
18 #include "gtest/gtest.h"
19 #include <random>
20 
21 using namespace llvm;
22 
23 namespace llvm {
24 extern cl::opt<bool> ShouldPreserveAllAttributes;
25 } // namespace llvm
26 
27 static void RunTest(
28     StringRef Head, StringRef Tail,
29     std::vector<std::pair<StringRef, llvm::function_ref<void(Instruction *)>>>
30         &Tests) {
31   for (auto &Elem : Tests) {
32     std::string IR;
33     IR.append(Head.begin(), Head.end());
34     IR.append(Elem.first.begin(), Elem.first.end());
35     IR.append(Tail.begin(), Tail.end());
36     LLVMContext C;
37     SMDiagnostic Err;
38     std::unique_ptr<Module> Mod = parseAssemblyString(IR, Err, C);
39     if (!Mod)
40       Err.print("AssumeQueryAPI", errs());
41     Elem.second(&*(Mod->getFunction("test")->begin()->begin()));
42   }
43 }
44 
45 bool hasMatchesExactlyAttributes(AssumeInst *Assume, Value *WasOn,
46                                  StringRef AttrToMatch) {
47   Regex Reg(AttrToMatch);
48   SmallVector<StringRef, 1> Matches;
49   for (StringRef Attr : {
50 #define GET_ATTR_NAMES
51 #define ATTRIBUTE_ALL(ENUM_NAME, DISPLAY_NAME) StringRef(#DISPLAY_NAME),
52 #include "llvm/IR/Attributes.inc"
53        }) {
54     bool ShouldHaveAttr = Reg.match(Attr, &Matches) && Matches[0] == Attr;
55     if (ShouldHaveAttr != hasAttributeInAssume(*Assume, WasOn, Attr))
56       return false;
57   }
58   return true;
59 }
60 
61 bool hasTheRightValue(AssumeInst *Assume, Value *WasOn,
62                       Attribute::AttrKind Kind, unsigned Value) {
63   uint64_t ArgVal = 0;
64   if (!hasAttributeInAssume(*Assume, WasOn, Kind, &ArgVal))
65     return false;
66   if (ArgVal != Value)
67     return false;
68   return true;
69 }
70 
71 TEST(AssumeQueryAPI, hasAttributeInAssume) {
72   EnableKnowledgeRetention.setValue(true);
73   StringRef Head =
74       "declare void @llvm.assume(i1)\n"
75       "declare void @func(i32*, i32*, i32*)\n"
76       "declare void @func1(i32*, i32*, i32*, i32*)\n"
77       "declare void @func_many(i32*) \"no-jump-tables\" nounwind "
78       "\"less-precise-fpmad\" willreturn norecurse\n"
79       "define void @test(i32* %P, i32* %P1, i32* %P2, i32* %P3) {\n";
80   StringRef Tail = "ret void\n"
81                    "}";
82   std::vector<std::pair<StringRef, llvm::function_ref<void(Instruction *)>>>
83       Tests;
84   Tests.push_back(std::make_pair(
85       "call void @func(i32* nonnull align 4 dereferenceable(16) %P, i32* align "
86       "8 noalias %P1, i32* align 8 noundef %P2)\n",
87       [](Instruction *I) {
88         auto *Assume = buildAssumeFromInst(I);
89         Assume->insertBefore(I);
90         ASSERT_TRUE(hasMatchesExactlyAttributes(Assume, I->getOperand(0),
91                                        "(nonnull|align|dereferenceable)"));
92         ASSERT_TRUE(hasMatchesExactlyAttributes(Assume, I->getOperand(1),
93                                        "()"));
94         ASSERT_TRUE(hasMatchesExactlyAttributes(Assume, I->getOperand(2),
95                                        "(align|noundef)"));
96         ASSERT_TRUE(hasTheRightValue(Assume, I->getOperand(0),
97                                      Attribute::AttrKind::Dereferenceable, 16));
98         ASSERT_TRUE(hasTheRightValue(Assume, I->getOperand(0),
99                                      Attribute::AttrKind::Alignment, 4));
100         ASSERT_TRUE(hasTheRightValue(Assume, I->getOperand(0),
101                                      Attribute::AttrKind::Alignment, 4));
102       }));
103   Tests.push_back(std::make_pair(
104       "call void @func1(i32* nonnull align 32 dereferenceable(48) %P, i32* "
105       "nonnull "
106       "align 8 dereferenceable(28) %P, i32* nonnull align 64 "
107       "dereferenceable(4) "
108       "%P, i32* nonnull align 16 dereferenceable(12) %P)\n",
109       [](Instruction *I) {
110         auto *Assume = buildAssumeFromInst(I);
111         Assume->insertBefore(I);
112         ASSERT_TRUE(hasMatchesExactlyAttributes(Assume, I->getOperand(0),
113                                        "(nonnull|align|dereferenceable)"));
114         ASSERT_TRUE(hasMatchesExactlyAttributes(Assume, I->getOperand(1),
115                                        "(nonnull|align|dereferenceable)"));
116         ASSERT_TRUE(hasMatchesExactlyAttributes(Assume, I->getOperand(2),
117                                        "(nonnull|align|dereferenceable)"));
118         ASSERT_TRUE(hasMatchesExactlyAttributes(Assume, I->getOperand(3),
119                                        "(nonnull|align|dereferenceable)"));
120         ASSERT_TRUE(hasTheRightValue(Assume, I->getOperand(0),
121                                      Attribute::AttrKind::Dereferenceable, 48));
122         ASSERT_TRUE(hasTheRightValue(Assume, I->getOperand(0),
123                                      Attribute::AttrKind::Alignment, 64));
124         ASSERT_TRUE(hasTheRightValue(Assume, I->getOperand(1),
125                                      Attribute::AttrKind::Alignment, 64));
126       }));
127   Tests.push_back(std::make_pair(
128       "call void @func_many(i32* align 8 noundef %P1) cold\n", [](Instruction *I) {
129         ShouldPreserveAllAttributes.setValue(true);
130         auto *Assume = buildAssumeFromInst(I);
131         Assume->insertBefore(I);
132         ASSERT_TRUE(hasMatchesExactlyAttributes(
133             Assume, nullptr,
134             "(align|nounwind|norecurse|noundef|willreturn|cold)"));
135         ShouldPreserveAllAttributes.setValue(false);
136       }));
137   Tests.push_back(
138       std::make_pair("call void @llvm.assume(i1 true)\n", [](Instruction *I) {
139         auto *Assume = cast<AssumeInst>(I);
140         ASSERT_TRUE(hasMatchesExactlyAttributes(Assume, nullptr, ""));
141       }));
142   Tests.push_back(std::make_pair(
143       "call void @func1(i32* readnone align 32 "
144       "dereferenceable(48) noalias %P, i32* "
145       "align 8 dereferenceable(28) %P1, i32* align 64 "
146       "dereferenceable(4) "
147       "%P2, i32* nonnull align 16 dereferenceable(12) %P3)\n",
148       [](Instruction *I) {
149         auto *Assume = buildAssumeFromInst(I);
150         Assume->insertBefore(I);
151         ASSERT_TRUE(hasMatchesExactlyAttributes(
152             Assume, I->getOperand(0),
153             "(align|dereferenceable)"));
154         ASSERT_TRUE(hasMatchesExactlyAttributes(Assume, I->getOperand(1),
155                                        "(align|dereferenceable)"));
156         ASSERT_TRUE(hasMatchesExactlyAttributes(Assume, I->getOperand(2),
157                                        "(align|dereferenceable)"));
158         ASSERT_TRUE(hasMatchesExactlyAttributes(Assume, I->getOperand(3),
159                                        "(nonnull|align|dereferenceable)"));
160         ASSERT_TRUE(hasTheRightValue(Assume, I->getOperand(0),
161                                      Attribute::AttrKind::Alignment, 32));
162         ASSERT_TRUE(hasTheRightValue(Assume, I->getOperand(0),
163                                      Attribute::AttrKind::Dereferenceable, 48));
164         ASSERT_TRUE(hasTheRightValue(Assume, I->getOperand(1),
165                                      Attribute::AttrKind::Dereferenceable, 28));
166         ASSERT_TRUE(hasTheRightValue(Assume, I->getOperand(1),
167                                      Attribute::AttrKind::Alignment, 8));
168         ASSERT_TRUE(hasTheRightValue(Assume, I->getOperand(2),
169                                      Attribute::AttrKind::Alignment, 64));
170         ASSERT_TRUE(hasTheRightValue(Assume, I->getOperand(2),
171                                      Attribute::AttrKind::Dereferenceable, 4));
172         ASSERT_TRUE(hasTheRightValue(Assume, I->getOperand(3),
173                                      Attribute::AttrKind::Alignment, 16));
174         ASSERT_TRUE(hasTheRightValue(Assume, I->getOperand(3),
175                                      Attribute::AttrKind::Dereferenceable, 12));
176       }));
177 
178   Tests.push_back(std::make_pair(
179       "call void @func1(i32* readnone align 32 "
180       "dereferenceable(48) noalias %P, i32* "
181       "align 8 dereferenceable(28) %P1, i32* align 64 "
182       "dereferenceable(4) "
183       "%P2, i32* nonnull align 16 dereferenceable(12) %P3)\n",
184       [](Instruction *I) {
185         auto *Assume = buildAssumeFromInst(I);
186         Assume->insertBefore(I);
187         I->getOperand(1)->dropDroppableUses();
188         I->getOperand(2)->dropDroppableUses();
189         I->getOperand(3)->dropDroppableUses();
190         ASSERT_TRUE(hasMatchesExactlyAttributes(
191             Assume, I->getOperand(0),
192             "(align|dereferenceable)"));
193         ASSERT_TRUE(hasMatchesExactlyAttributes(Assume, I->getOperand(1),
194                                        ""));
195         ASSERT_TRUE(hasMatchesExactlyAttributes(Assume, I->getOperand(2),
196                                        ""));
197         ASSERT_TRUE(hasMatchesExactlyAttributes(Assume, I->getOperand(3),
198                                        ""));
199         ASSERT_TRUE(hasTheRightValue(Assume, I->getOperand(0),
200                                      Attribute::AttrKind::Alignment, 32));
201         ASSERT_TRUE(hasTheRightValue(Assume, I->getOperand(0),
202                                      Attribute::AttrKind::Dereferenceable, 48));
203       }));
204   Tests.push_back(std::make_pair(
205       "call void @func(i32* nonnull align 4 dereferenceable(16) %P, i32* align "
206       "8 noalias %P1, i32* %P1)\n",
207       [](Instruction *I) {
208         auto *Assume = buildAssumeFromInst(I);
209         Assume->insertBefore(I);
210         Value *New = I->getFunction()->getArg(3);
211         Value *Old = I->getOperand(0);
212         ASSERT_TRUE(hasMatchesExactlyAttributes(Assume, New, ""));
213         ASSERT_TRUE(hasMatchesExactlyAttributes(Assume, Old,
214                                        "(nonnull|align|dereferenceable)"));
215         Old->replaceAllUsesWith(New);
216         ASSERT_TRUE(hasMatchesExactlyAttributes(Assume, New,
217                                        "(nonnull|align|dereferenceable)"));
218         ASSERT_TRUE(hasMatchesExactlyAttributes(Assume, Old, ""));
219       }));
220   RunTest(Head, Tail, Tests);
221 }
222 
223 static bool FindExactlyAttributes(RetainedKnowledgeMap &Map, Value *WasOn,
224                                  StringRef AttrToMatch) {
225   Regex Reg(AttrToMatch);
226   SmallVector<StringRef, 1> Matches;
227   for (StringRef Attr : {
228 #define GET_ATTR_NAMES
229 #define ATTRIBUTE_ENUM(ENUM_NAME, DISPLAY_NAME) StringRef(#DISPLAY_NAME),
230 #include "llvm/IR/Attributes.inc"
231        }) {
232     bool ShouldHaveAttr = Reg.match(Attr, &Matches) && Matches[0] == Attr;
233 
234     if (ShouldHaveAttr != (Map.contains(RetainedKnowledgeKey{
235                               WasOn, Attribute::getAttrKindFromName(Attr)})))
236       return false;
237   }
238   return true;
239 }
240 
241 static bool MapHasRightValue(RetainedKnowledgeMap &Map, AssumeInst *II,
242                              RetainedKnowledgeKey Key, MinMax MM) {
243   auto LookupIt = Map.find(Key);
244   return (LookupIt != Map.end()) && (LookupIt->second[II].Min == MM.Min) &&
245          (LookupIt->second[II].Max == MM.Max);
246 }
247 
248 TEST(AssumeQueryAPI, fillMapFromAssume) {
249   EnableKnowledgeRetention.setValue(true);
250   StringRef Head =
251       "declare void @llvm.assume(i1)\n"
252       "declare void @func(i32*, i32*, i32*)\n"
253       "declare void @func1(i32*, i32*, i32*, i32*)\n"
254       "declare void @func_many(i32*) \"no-jump-tables\" nounwind "
255       "\"less-precise-fpmad\" willreturn norecurse\n"
256       "define void @test(i32* %P, i32* %P1, i32* %P2, i32* %P3) {\n";
257   StringRef Tail = "ret void\n"
258                    "}";
259   std::vector<std::pair<StringRef, llvm::function_ref<void(Instruction *)>>>
260       Tests;
261   Tests.push_back(std::make_pair(
262       "call void @func(i32* nonnull align 4 dereferenceable(16) %P, i32* align "
263       "8 noalias %P1, i32* align 8 dereferenceable(8) %P2)\n",
264       [](Instruction *I) {
265         auto *Assume = buildAssumeFromInst(I);
266         Assume->insertBefore(I);
267 
268         RetainedKnowledgeMap Map;
269         fillMapFromAssume(*Assume, Map);
270         ASSERT_TRUE(FindExactlyAttributes(Map, I->getOperand(0),
271                                        "(nonnull|align|dereferenceable)"));
272         ASSERT_FALSE(FindExactlyAttributes(Map, I->getOperand(1),
273                                        "(align)"));
274         ASSERT_TRUE(FindExactlyAttributes(Map, I->getOperand(2),
275                                        "(align|dereferenceable)"));
276         ASSERT_TRUE(MapHasRightValue(
277             Map, Assume, {I->getOperand(0), Attribute::Dereferenceable}, {16, 16}));
278         ASSERT_TRUE(MapHasRightValue(Map, Assume, {I->getOperand(0), Attribute::Alignment},
279                                {4, 4}));
280         ASSERT_TRUE(MapHasRightValue(Map, Assume, {I->getOperand(0), Attribute::Alignment},
281                                {4, 4}));
282       }));
283   Tests.push_back(std::make_pair(
284       "call void @func1(i32* nonnull align 32 dereferenceable(48) %P, i32* "
285       "nonnull "
286       "align 8 dereferenceable(28) %P, i32* nonnull align 64 "
287       "dereferenceable(4) "
288       "%P, i32* nonnull align 16 dereferenceable(12) %P)\n",
289       [](Instruction *I) {
290         auto *Assume = buildAssumeFromInst(I);
291         Assume->insertBefore(I);
292 
293         RetainedKnowledgeMap Map;
294         fillMapFromAssume(*Assume, Map);
295 
296         ASSERT_TRUE(FindExactlyAttributes(Map, I->getOperand(0),
297                                        "(nonnull|align|dereferenceable)"));
298         ASSERT_TRUE(FindExactlyAttributes(Map, I->getOperand(1),
299                                        "(nonnull|align|dereferenceable)"));
300         ASSERT_TRUE(FindExactlyAttributes(Map, I->getOperand(2),
301                                        "(nonnull|align|dereferenceable)"));
302         ASSERT_TRUE(FindExactlyAttributes(Map, I->getOperand(3),
303                                        "(nonnull|align|dereferenceable)"));
304         ASSERT_TRUE(MapHasRightValue(
305             Map, Assume, {I->getOperand(0), Attribute::Dereferenceable},
306             {48, 48}));
307         ASSERT_TRUE(MapHasRightValue(
308             Map, Assume, {I->getOperand(0), Attribute::Alignment}, {64, 64}));
309       }));
310   Tests.push_back(std::make_pair(
311       "call void @func_many(i32* align 8 %P1) cold\n", [](Instruction *I) {
312         ShouldPreserveAllAttributes.setValue(true);
313         auto *Assume = buildAssumeFromInst(I);
314         Assume->insertBefore(I);
315 
316         RetainedKnowledgeMap Map;
317         fillMapFromAssume(*Assume, Map);
318 
319         ASSERT_TRUE(FindExactlyAttributes(
320             Map, nullptr, "(nounwind|norecurse|willreturn|cold)"));
321         ShouldPreserveAllAttributes.setValue(false);
322       }));
323   Tests.push_back(
324       std::make_pair("call void @llvm.assume(i1 true)\n", [](Instruction *I) {
325         RetainedKnowledgeMap Map;
326         fillMapFromAssume(*cast<AssumeInst>(I), Map);
327 
328         ASSERT_TRUE(FindExactlyAttributes(Map, nullptr, ""));
329         ASSERT_TRUE(Map.empty());
330       }));
331   Tests.push_back(std::make_pair(
332       "call void @func1(i32* readnone align 32 "
333       "dereferenceable(48) noalias %P, i32* "
334       "align 8 dereferenceable(28) %P1, i32* align 64 "
335       "dereferenceable(4) "
336       "%P2, i32* nonnull align 16 dereferenceable(12) %P3)\n",
337       [](Instruction *I) {
338         auto *Assume = buildAssumeFromInst(I);
339         Assume->insertBefore(I);
340 
341         RetainedKnowledgeMap Map;
342         fillMapFromAssume(*Assume, Map);
343 
344         ASSERT_TRUE(FindExactlyAttributes(Map, I->getOperand(0),
345                                     "(align|dereferenceable)"));
346         ASSERT_TRUE(FindExactlyAttributes(Map, I->getOperand(1),
347                                     "(align|dereferenceable)"));
348         ASSERT_TRUE(FindExactlyAttributes(Map, I->getOperand(2),
349                                        "(align|dereferenceable)"));
350         ASSERT_TRUE(FindExactlyAttributes(Map, I->getOperand(3),
351                                        "(nonnull|align|dereferenceable)"));
352         ASSERT_TRUE(MapHasRightValue(Map, Assume, {I->getOperand(0), Attribute::Alignment},
353                                {32, 32}));
354         ASSERT_TRUE(MapHasRightValue(
355             Map, Assume, {I->getOperand(0), Attribute::Dereferenceable}, {48, 48}));
356         ASSERT_TRUE(MapHasRightValue(
357             Map, Assume, {I->getOperand(1), Attribute::Dereferenceable}, {28, 28}));
358         ASSERT_TRUE(MapHasRightValue(Map, Assume, {I->getOperand(1), Attribute::Alignment},
359                                {8, 8}));
360         ASSERT_TRUE(MapHasRightValue(Map, Assume, {I->getOperand(2), Attribute::Alignment},
361                                {64, 64}));
362         ASSERT_TRUE(MapHasRightValue(
363             Map, Assume, {I->getOperand(2), Attribute::Dereferenceable}, {4, 4}));
364         ASSERT_TRUE(MapHasRightValue(Map, Assume, {I->getOperand(3), Attribute::Alignment},
365                                {16, 16}));
366         ASSERT_TRUE(MapHasRightValue(
367             Map, Assume, {I->getOperand(3), Attribute::Dereferenceable}, {12, 12}));
368       }));
369 
370   /// Keep this test last as it modifies the function.
371   Tests.push_back(std::make_pair(
372       "call void @func(i32* nonnull align 4 dereferenceable(16) %P, i32* align "
373       "8 noalias %P1, i32* %P2)\n",
374       [](Instruction *I) {
375         auto *Assume = buildAssumeFromInst(I);
376         Assume->insertBefore(I);
377 
378         RetainedKnowledgeMap Map;
379         fillMapFromAssume(*Assume, Map);
380 
381         Value *New = I->getFunction()->getArg(3);
382         Value *Old = I->getOperand(0);
383         ASSERT_TRUE(FindExactlyAttributes(Map, New, ""));
384         ASSERT_TRUE(FindExactlyAttributes(Map, Old,
385                                        "(nonnull|align|dereferenceable)"));
386         Old->replaceAllUsesWith(New);
387         Map.clear();
388         fillMapFromAssume(*Assume, Map);
389         ASSERT_TRUE(FindExactlyAttributes(Map, New,
390                                        "(nonnull|align|dereferenceable)"));
391         ASSERT_TRUE(FindExactlyAttributes(Map, Old, ""));
392       }));
393   Tests.push_back(std::make_pair(
394       "call void @llvm.assume(i1 true) [\"align\"(i8* undef, i32 undef)]",
395       [](Instruction *I) {
396         // Don't crash but don't learn from undef.
397         RetainedKnowledgeMap Map;
398         fillMapFromAssume(*cast<AssumeInst>(I), Map);
399 
400         ASSERT_TRUE(Map.empty());
401       }));
402   RunTest(Head, Tail, Tests);
403 }
404 
405 static void RunRandTest(uint64_t Seed, int Size, int MinCount, int MaxCount,
406                         unsigned MaxValue) {
407   LLVMContext C;
408   SMDiagnostic Err;
409 
410   std::mt19937 Rng(Seed);
411   std::uniform_int_distribution<int> DistCount(MinCount, MaxCount);
412   std::uniform_int_distribution<unsigned> DistValue(0, MaxValue);
413   std::uniform_int_distribution<unsigned> DistAttr(0,
414                                                    Attribute::EndAttrKinds - 1);
415 
416   std::unique_ptr<Module> Mod = std::make_unique<Module>("AssumeQueryAPI", C);
417   if (!Mod)
418     Err.print("AssumeQueryAPI", errs());
419 
420   std::vector<Type *> TypeArgs;
421   for (int i = 0; i < (Size * 2); i++)
422     TypeArgs.push_back(PointerType::getUnqual(C));
423   FunctionType *FuncType =
424       FunctionType::get(Type::getVoidTy(C), TypeArgs, false);
425 
426   Function *F =
427       Function::Create(FuncType, GlobalValue::ExternalLinkage, "test", &*Mod);
428   BasicBlock *BB = BasicBlock::Create(C);
429   BB->insertInto(F);
430   Instruction *Ret = ReturnInst::Create(C);
431   Ret->insertInto(BB, BB->begin());
432   Function *FnAssume = Intrinsic::getDeclaration(Mod.get(), Intrinsic::assume);
433 
434   std::vector<Argument *> ShuffledArgs;
435   BitVector HasArg;
436   for (auto &Arg : F->args()) {
437     ShuffledArgs.push_back(&Arg);
438     HasArg.push_back(false);
439   }
440 
441   std::shuffle(ShuffledArgs.begin(), ShuffledArgs.end(), Rng);
442 
443   std::vector<OperandBundleDef> OpBundle;
444   OpBundle.reserve(Size);
445   std::vector<Value *> Args;
446   Args.reserve(2);
447   for (int i = 0; i < Size; i++) {
448     int count = DistCount(Rng);
449     int value = DistValue(Rng);
450     int attr = DistAttr(Rng);
451     std::string str;
452     raw_string_ostream ss(str);
453     ss << Attribute::getNameFromAttrKind(
454         static_cast<Attribute::AttrKind>(attr));
455     Args.clear();
456 
457     if (count > 0) {
458       Args.push_back(ShuffledArgs[i]);
459       HasArg[i] = true;
460     }
461     if (count > 1)
462       Args.push_back(ConstantInt::get(Type::getInt32Ty(C), value));
463 
464     OpBundle.push_back(OperandBundleDef{ss.str().c_str(), std::move(Args)});
465   }
466 
467   auto *Assume = cast<AssumeInst>(CallInst::Create(
468       FnAssume, ArrayRef<Value *>({ConstantInt::getTrue(C)}), OpBundle));
469   Assume->insertBefore(&F->begin()->front());
470   RetainedKnowledgeMap Map;
471   fillMapFromAssume(*Assume, Map);
472   for (int i = 0; i < (Size * 2); i++) {
473     if (!HasArg[i])
474       continue;
475     RetainedKnowledge K =
476         getKnowledgeFromUseInAssume(&*ShuffledArgs[i]->use_begin());
477     auto LookupIt = Map.find(RetainedKnowledgeKey{K.WasOn, K.AttrKind});
478     ASSERT_TRUE(LookupIt != Map.end());
479     MinMax MM = LookupIt->second[Assume];
480     ASSERT_TRUE(MM.Min == MM.Max);
481     ASSERT_TRUE(MM.Min == K.ArgValue);
482   }
483 }
484 
485 TEST(AssumeQueryAPI, getKnowledgeFromUseInAssume) {
486   // // For Fuzzing
487   // std::random_device dev;
488   // std::mt19937 Rng(dev());
489   // while (true) {
490   //   unsigned Seed = Rng();
491   //   dbgs() << Seed << "\n";
492   //   RunRandTest(Seed, 100000, 0, 2, 100);
493   // }
494   RunRandTest(23456, 4, 0, 2, 100);
495   RunRandTest(560987, 25, -3, 2, 100);
496 
497   // Large bundles can lead to special cases. this is why this test is soo
498   // large.
499   RunRandTest(9876789, 100000, -0, 7, 100);
500 }
501 
502 TEST(AssumeQueryAPI, AssumptionCache) {
503   LLVMContext C;
504   SMDiagnostic Err;
505   std::unique_ptr<Module> Mod = parseAssemblyString(
506       "declare void @llvm.assume(i1)\n"
507       "define void @test(i32* %P, i32* %P1, i32* %P2, i32* %P3, i1 %B) {\n"
508       "call void @llvm.assume(i1 true) [\"nonnull\"(i32* %P), \"align\"(i32* "
509       "%P2, i32 4), \"align\"(i32* %P, i32 8)]\n"
510       "call void @llvm.assume(i1 %B) [\"test\"(i32* %P1), "
511       "\"dereferenceable\"(i32* %P, i32 4)]\n"
512       "ret void\n}\n",
513       Err, C);
514   if (!Mod)
515     Err.print("AssumeQueryAPI", errs());
516   Function *F = Mod->getFunction("test");
517   BasicBlock::iterator First = F->begin()->begin();
518   BasicBlock::iterator Second = F->begin()->begin();
519   Second++;
520   AssumptionCache AC(*F);
521   auto AR = AC.assumptionsFor(F->getArg(3));
522   ASSERT_EQ(AR.size(), 0u);
523   AR = AC.assumptionsFor(F->getArg(1));
524   ASSERT_EQ(AR.size(), 1u);
525   ASSERT_EQ(AR[0].Index, 0u);
526   ASSERT_EQ(AR[0].Assume, &*Second);
527   AR = AC.assumptionsFor(F->getArg(2));
528   ASSERT_EQ(AR.size(), 1u);
529   ASSERT_EQ(AR[0].Index, 1u);
530   ASSERT_EQ(AR[0].Assume, &*First);
531   AR = AC.assumptionsFor(F->getArg(0));
532   ASSERT_EQ(AR.size(), 3u);
533   llvm::sort(AR,
534              [](const auto &L, const auto &R) { return L.Index < R.Index; });
535   ASSERT_EQ(AR[0].Assume, &*First);
536   ASSERT_EQ(AR[0].Index, 0u);
537   ASSERT_EQ(AR[1].Assume, &*Second);
538   ASSERT_EQ(AR[1].Index, 1u);
539   ASSERT_EQ(AR[2].Assume, &*First);
540   ASSERT_EQ(AR[2].Index, 2u);
541   AR = AC.assumptionsFor(F->getArg(4));
542   ASSERT_EQ(AR.size(), 1u);
543   ASSERT_EQ(AR[0].Assume, &*Second);
544   ASSERT_EQ(AR[0].Index, AssumptionCache::ExprResultIdx);
545   AC.unregisterAssumption(cast<AssumeInst>(&*Second));
546   AR = AC.assumptionsFor(F->getArg(1));
547   ASSERT_EQ(AR.size(), 0u);
548   AR = AC.assumptionsFor(F->getArg(0));
549   ASSERT_EQ(AR.size(), 3u);
550   llvm::sort(AR,
551              [](const auto &L, const auto &R) { return L.Index < R.Index; });
552   ASSERT_EQ(AR[0].Assume, &*First);
553   ASSERT_EQ(AR[0].Index, 0u);
554   ASSERT_EQ(AR[1].Assume, nullptr);
555   ASSERT_EQ(AR[1].Index, 1u);
556   ASSERT_EQ(AR[2].Assume, &*First);
557   ASSERT_EQ(AR[2].Index, 2u);
558   AR = AC.assumptionsFor(F->getArg(2));
559   ASSERT_EQ(AR.size(), 1u);
560   ASSERT_EQ(AR[0].Index, 1u);
561   ASSERT_EQ(AR[0].Assume, &*First);
562 }
563 
564 TEST(AssumeQueryAPI, Alignment) {
565   LLVMContext C;
566   SMDiagnostic Err;
567   std::unique_ptr<Module> Mod = parseAssemblyString(
568       "declare void @llvm.assume(i1)\n"
569       "define void @test(i32* %P, i32* %P1, i32* %P2, i32 %I3, i1 %B) {\n"
570       "call void @llvm.assume(i1 true) [\"align\"(i32* %P, i32 8, i32 %I3)]\n"
571       "call void @llvm.assume(i1 true) [\"align\"(i32* %P1, i32 %I3, i32 "
572       "%I3)]\n"
573       "call void @llvm.assume(i1 true) [\"align\"(i32* %P2, i32 16, i32 8)]\n"
574       "ret void\n}\n",
575       Err, C);
576   if (!Mod)
577     Err.print("AssumeQueryAPI", errs());
578 
579   Function *F = Mod->getFunction("test");
580   BasicBlock::iterator Start = F->begin()->begin();
581   AssumeInst *II;
582   RetainedKnowledge RK;
583   II = cast<AssumeInst>(&*Start);
584   RK = getKnowledgeFromBundle(*II, II->bundle_op_info_begin()[0]);
585   ASSERT_EQ(RK.AttrKind, Attribute::Alignment);
586   ASSERT_EQ(RK.WasOn, F->getArg(0));
587   ASSERT_EQ(RK.ArgValue, 1u);
588   Start++;
589   II = cast<AssumeInst>(&*Start);
590   RK = getKnowledgeFromBundle(*II, II->bundle_op_info_begin()[0]);
591   ASSERT_EQ(RK.AttrKind, Attribute::Alignment);
592   ASSERT_EQ(RK.WasOn, F->getArg(1));
593   ASSERT_EQ(RK.ArgValue, 1u);
594   Start++;
595   II = cast<AssumeInst>(&*Start);
596   RK = getKnowledgeFromBundle(*II, II->bundle_op_info_begin()[0]);
597   ASSERT_EQ(RK.AttrKind, Attribute::Alignment);
598   ASSERT_EQ(RK.WasOn, F->getArg(2));
599   ASSERT_EQ(RK.ArgValue, 8u);
600 }
601