xref: /llvm-project/llvm/unittests/Transforms/Utils/ScalarEvolutionExpanderTest.cpp (revision 64b9753d03946d8100e017a5cc4861d5d671c6d0)
1 //=== ScalarEvolutionExpanderTest.cpp - ScalarEvolutionExpander unit tests ===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "llvm/Transforms/Utils/ScalarEvolutionExpander.h"
10 #include "llvm/ADT/SmallVector.h"
11 #include "llvm/Analysis/AssumptionCache.h"
12 #include "llvm/Analysis/LoopInfo.h"
13 #include "llvm/Analysis/ScalarEvolutionExpressions.h"
14 #include "llvm/Analysis/TargetLibraryInfo.h"
15 #include "llvm/AsmParser/Parser.h"
16 #include "llvm/IR/Constants.h"
17 #include "llvm/IR/Dominators.h"
18 #include "llvm/IR/GlobalVariable.h"
19 #include "llvm/IR/IRBuilder.h"
20 #include "llvm/IR/InstIterator.h"
21 #include "llvm/IR/LLVMContext.h"
22 #include "llvm/IR/Module.h"
23 #include "llvm/IR/PatternMatch.h"
24 #include "llvm/IR/Verifier.h"
25 #include "llvm/Support/SourceMgr.h"
26 #include "gtest/gtest.h"
27 
28 namespace llvm {
29 
30 using namespace PatternMatch;
31 
// Test fixture that owns the analyses (AssumptionCache, DominatorTree,
// LoopInfo) a ScalarEvolution instance depends on, keeping them alive for the
// duration of each test so the ScalarEvolution returned by buildSE() remains
// valid while the test body runs.
class ScalarEvolutionExpanderTest : public testing::Test {
protected:
  LLVMContext Context;
  Module M; // Default empty module; several tests build their own instead.
  TargetLibraryInfoImpl TLII;
  TargetLibraryInfo TLI;

  // Recreated on every buildSE() call; must outlive the returned
  // ScalarEvolution, which holds references to all three.
  std::unique_ptr<AssumptionCache> AC;
  std::unique_ptr<DominatorTree> DT;
  std::unique_ptr<LoopInfo> LI;

  ScalarEvolutionExpanderTest() : M("", Context), TLII(), TLI(TLII) {}

  // Build a fresh ScalarEvolution for \p F, rebuilding the supporting
  // analyses (AC/DT/LI) as a side effect.
  ScalarEvolution buildSE(Function &F) {
    AC.reset(new AssumptionCache(F));
    DT.reset(new DominatorTree(F));
    LI.reset(new LoopInfo(*DT));
    return ScalarEvolution(F, TLI, *AC, *DT, *LI);
  }

  // Look up \p FuncName in \p M (note: this parameter deliberately shadows
  // the member M), build ScalarEvolution for it, and invoke \p Test with the
  // function, the fixture's LoopInfo, and the new ScalarEvolution.
  void runWithSE(
      Module &M, StringRef FuncName,
      function_ref<void(Function &F, LoopInfo &LI, ScalarEvolution &SE)> Test) {
    auto *F = M.getFunction(FuncName);
    ASSERT_NE(F, nullptr) << "Could not find " << FuncName;
    ScalarEvolution SE = buildSE(*F);
    Test(*F, *LI, SE);
  }
};
63 
64 static Instruction &GetInstByName(Function &F, StringRef Name) {
65   for (auto &I : instructions(F))
66     if (I.getName() == Name)
67       return I;
68   llvm_unreachable("Could not find instructions!");
69 }
70 
TEST_F(ScalarEvolutionExpanderTest, ExpandPtrTypeSCEV) {
  // It is to test the fix for PR30213. It exercises the branch in scev
  // expansion when the value in ValueOffsetPair is a ptr and the offset
  // is not divisible by the elem type size of value.
  auto *I8Ty = Type::getInt8Ty(Context);
  auto *PtrTy = PointerType::get(Context, 0);
  auto *I32Ty = Type::getInt32Ty(Context);
  FunctionType *FTy =
      FunctionType::get(Type::getVoidTy(Context), std::vector<Type *>(), false);
  Function *F = Function::Create(FTy, Function::ExternalLinkage, "f", M);
  BasicBlock *EntryBB = BasicBlock::Create(Context, "entry", F);
  BasicBlock *LoopBB = BasicBlock::Create(Context, "loop", F);
  BasicBlock *ExitBB = BasicBlock::Create(Context, "exit", F);
  BranchInst::Create(LoopBB, EntryBB);
  ReturnInst::Create(Context, nullptr, ExitBB);

  // Build the loop body by hand; each instruction is inserted immediately
  // before the loop's terminator branch:
  // loop:                            ; preds = %loop, %entry
  //   %alloca = alloca i32
  //   %gep0 = getelementptr i32, ptr %alloca, i32 1
  //   %gep1 = getelementptr i8, ptr %gep0, i32 1
  //   %gep2 = getelementptr i8, ptr undef, i32 1
  //   %cmp = icmp ult ptr undef, %gep0
  //   %select = select i1 %cmp, ptr %gep1, ptr %gep2
  //   br i1 poison, label %loop, label %exit

  const DataLayout &DL = F->getDataLayout();
  BranchInst *Br = BranchInst::Create(
      LoopBB, ExitBB, PoisonValue::get(Type::getInt1Ty(Context)), LoopBB);
  AllocaInst *Alloca = new AllocaInst(I32Ty, DL.getAllocaAddrSpace(), "alloca",
                                      Br->getIterator());
  ConstantInt *Ci32 = ConstantInt::get(Context, APInt(32, 1));
  UndefValue *UndefPtr = UndefValue::get(PtrTy);
  GetElementPtrInst *Gep0 =
      GetElementPtrInst::Create(I32Ty, Alloca, Ci32, "gep0", Br->getIterator());
  // %gep1 offsets %gep0 by one *byte*, i.e. not a multiple of the i32 elem
  // size — this is the shape that triggered PR30213.
  GetElementPtrInst *Gep1 =
      GetElementPtrInst::Create(I8Ty, Gep0, Ci32, "gep1", Br->getIterator());
  GetElementPtrInst *Gep2 = GetElementPtrInst::Create(
      I8Ty, UndefPtr, Ci32, "gep2", Br->getIterator());
  CmpInst *Cmp = CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_ULT,
                                 UndefPtr, Gep0, "cmp", Br->getIterator());
  SelectInst *Select =
      SelectInst::Create(Cmp, Gep1, Gep2, "select", Br->getIterator());

  // SCEV cannot see through the select of two unrelated pointers, so the
  // expression must be an opaque SCEVUnknown (and must not crash).
  ScalarEvolution SE = buildSE(*F);
  const SCEV *S = SE.getSCEV(Select);
  EXPECT_TRUE(isa<SCEVUnknown>(S));
}
118 
// Make sure that SCEV doesn't introduce illegal ptrtoint/inttoptr instructions
TEST_F(ScalarEvolutionExpanderTest, SCEVZeroExtendExprNonIntegral) {
  /*
   * Create the following code:
   * func(i64 addrspace(10)* %arg)
   * top:
   *  br label %L.ph
   * L.ph:
   *  %gepbase = getelementptr i64 addrspace(10)* %arg, i64 1
   *  br label %L
   * L:
   *  %phi = phi i64 [i64 0, %L.ph], [ %add, %L ]
   *  %add = add i64 %phi, 1
   *  br i1 poison, label %L, label %post
   * post:
   *  #= %gep = getelementptr i64 addrspace(10)* %gepbase, i64 %add =#
   *  ret void
   *
   * We will create the appropriate SCEV expression for %gep and expand it,
   * then check that no inttoptr/ptrtoint instructions got inserted.
   */

  // Create a module with non-integral pointers in its datalayout: address
  // space 10 is marked "ni", so ptrtoint/inttoptr on it are invalid.
  Module NIM("nonintegral", Context);
  std::string DataLayout = M.getDataLayoutStr();
  if (!DataLayout.empty())
    DataLayout += "-";
  DataLayout += "ni:10";
  NIM.setDataLayout(DataLayout);

  Type *T_int1 = Type::getInt1Ty(Context);
  Type *T_int64 = Type::getInt64Ty(Context);
  Type *T_pint64 = PointerType::get(Context, 10);

  FunctionType *FTy =
      FunctionType::get(Type::getVoidTy(Context), {T_pint64}, false);
  Function *F = Function::Create(FTy, Function::ExternalLinkage, "foo", NIM);

  Argument *Arg = &*F->arg_begin();

  BasicBlock *Top = BasicBlock::Create(Context, "top", F);
  BasicBlock *LPh = BasicBlock::Create(Context, "L.ph", F);
  BasicBlock *L = BasicBlock::Create(Context, "L", F);
  BasicBlock *Post = BasicBlock::Create(Context, "post", F);

  IRBuilder<> Builder(Top);
  Builder.CreateBr(LPh);

  Builder.SetInsertPoint(LPh);
  Value *GepBase =
      Builder.CreateGEP(T_int64, Arg, ConstantInt::get(T_int64, 1));
  Builder.CreateBr(L);

  Builder.SetInsertPoint(L);
  PHINode *Phi = Builder.CreatePHI(T_int64, 2);
  Value *Add = Builder.CreateAdd(Phi, ConstantInt::get(T_int64, 1), "add");
  Builder.CreateCondBr(PoisonValue::get(T_int1), L, Post);
  Phi->addIncoming(ConstantInt::get(T_int64, 0), LPh);
  Phi->addIncoming(Add, L);

  Builder.SetInsertPoint(Post);
  Instruction *Ret = Builder.CreateRetVoid();

  // Build {%gepbase,+,1}<nuw> by hand — the addrec a GEP of %gepbase by the
  // induction variable would produce.
  ScalarEvolution SE = buildSE(*F);
  const SCEV *AddRec =
      SE.getAddRecExpr(SE.getUnknown(GepBase), SE.getConstant(T_int64, 1),
                       LI->getLoopFor(L), SCEV::FlagNUW);

  SCEVExpander Exp(SE, NIM.getDataLayout(), "expander");
  Exp.disableCanonicalMode();
  Exp.expandCodeFor(AddRec, T_pint64, Ret);

  // Make sure none of the instructions inserted were inttoptr/ptrtoint.
  // The verifier will check this.
  EXPECT_FALSE(verifyFunction(*F, &errs()));
}
195 
// Check that we can correctly identify the points at which the SCEV of the
// AddRec can be expanded.
TEST_F(ScalarEvolutionExpanderTest, SCEVExpanderIsSafeToExpandAt) {
  /*
   * Create the following code:
   * func(i64 addrspace(10)* %arg)
   * top:
   *  br label %L.ph
   * L.ph:
   *  br label %L
   * L:
   *  %phi = phi i64 [i64 0, %L.ph], [ %add, %L ]
   *  %add = add i64 %phi, 1
   *  %cond = icmp slt i64 %add, 1000
   *  br i1 %cond, label %L, label %post
   * post:
   *  ret void
   *
   */

  // Create a module with non-integral pointers in its datalayout
  Module NIM("nonintegral", Context);
  std::string DataLayout = M.getDataLayoutStr();
  if (!DataLayout.empty())
    DataLayout += "-";
  DataLayout += "ni:10";
  NIM.setDataLayout(DataLayout);

  Type *T_int64 = Type::getInt64Ty(Context);
  Type *T_pint64 = PointerType::get(Context, 10);

  FunctionType *FTy =
      FunctionType::get(Type::getVoidTy(Context), {T_pint64}, false);
  Function *F = Function::Create(FTy, Function::ExternalLinkage, "foo", NIM);

  BasicBlock *Top = BasicBlock::Create(Context, "top", F);
  BasicBlock *LPh = BasicBlock::Create(Context, "L.ph", F);
  BasicBlock *L = BasicBlock::Create(Context, "L", F);
  BasicBlock *Post = BasicBlock::Create(Context, "post", F);

  IRBuilder<> Builder(Top);
  Builder.CreateBr(LPh);

  Builder.SetInsertPoint(LPh);
  Builder.CreateBr(L);

  Builder.SetInsertPoint(L);
  PHINode *Phi = Builder.CreatePHI(T_int64, 2);
  auto *Add = cast<Instruction>(
      Builder.CreateAdd(Phi, ConstantInt::get(T_int64, 1), "add"));
  auto *Limit = ConstantInt::get(T_int64, 1000);
  auto *Cond = cast<Instruction>(
      Builder.CreateICmp(ICmpInst::ICMP_SLT, Add, Limit, "cond"));
  Builder.CreateCondBr(Cond, L, Post);
  Phi->addIncoming(ConstantInt::get(T_int64, 0), LPh);
  Phi->addIncoming(Add, L);

  Builder.SetInsertPoint(Post);
  Instruction *Ret = Builder.CreateRetVoid();

  ScalarEvolution SE = buildSE(*F);
  SCEVExpander Exp(SE, M.getDataLayout(), "expander");
  const SCEV *S = SE.getSCEV(Phi);
  EXPECT_TRUE(isa<SCEVAddRecExpr>(S));
  const SCEVAddRecExpr *AR = cast<SCEVAddRecExpr>(S);
  EXPECT_TRUE(AR->isAffine());
  // The addrec is only expandable where its loop is reachable/dominated:
  // not above the loop preheader's terminator, but inside the loop and
  // after it.
  EXPECT_FALSE(Exp.isSafeToExpandAt(AR, Top->getTerminator()));
  EXPECT_FALSE(Exp.isSafeToExpandAt(AR, LPh->getTerminator()));
  EXPECT_TRUE(Exp.isSafeToExpandAt(AR, L->getTerminator()));
  EXPECT_TRUE(Exp.isSafeToExpandAt(AR, Post->getTerminator()));

  // Expansion of %add's SCEV outside the loop must keep the loop in LCSSA.
  EXPECT_TRUE(LI->getLoopFor(L)->isLCSSAForm(*DT));
  Exp.expandCodeFor(SE.getSCEV(Add), nullptr, Ret);
  EXPECT_TRUE(LI->getLoopFor(L)->isLCSSAForm(*DT));
}
271 
272 // Check that SCEV expander does not use the nuw instruction
273 // for expansion.
274 TEST_F(ScalarEvolutionExpanderTest, SCEVExpanderNUW) {
275   /*
276    * Create the following code:
277    * func(i64 %a)
278    * entry:
279    *   br false, label %exit, label %body
280    * body:
281    *  %s1 = add i64 %a, -1
282    *  br label %exit
283    * exit:
284    *  %s = add nuw i64 %a, -1
285    *  ret %s
286    */
287 
288   // Create a module.
289   Module M("SCEVExpanderNUW", Context);
290 
291   Type *T_int64 = Type::getInt64Ty(Context);
292 
293   FunctionType *FTy =
294       FunctionType::get(Type::getVoidTy(Context), {T_int64}, false);
295   Function *F = Function::Create(FTy, Function::ExternalLinkage, "func", M);
296   Argument *Arg = &*F->arg_begin();
297   ConstantInt *C = ConstantInt::get(Context, APInt(64, -1));
298 
299   BasicBlock *Entry = BasicBlock::Create(Context, "entry", F);
300   BasicBlock *Body = BasicBlock::Create(Context, "body", F);
301   BasicBlock *Exit = BasicBlock::Create(Context, "exit", F);
302 
303   IRBuilder<> Builder(Entry);
304   ConstantInt *Cond = ConstantInt::get(Context, APInt(1, 0));
305   Builder.CreateCondBr(Cond, Exit, Body);
306 
307   Builder.SetInsertPoint(Body);
308   auto *S1 = cast<Instruction>(Builder.CreateAdd(Arg, C, "add"));
309   Builder.CreateBr(Exit);
310 
311   Builder.SetInsertPoint(Exit);
312   auto *S2 = cast<Instruction>(Builder.CreateAdd(Arg, C, "add"));
313   S2->setHasNoUnsignedWrap(true);
314   auto *R = cast<Instruction>(Builder.CreateRetVoid());
315 
316   ScalarEvolution SE = buildSE(*F);
317   const SCEV *S = SE.getSCEV(S1);
318   EXPECT_TRUE(isa<SCEVAddExpr>(S));
319   SCEVExpander Exp(SE, M.getDataLayout(), "expander");
320   auto *I = cast<Instruction>(Exp.expandCodeFor(S, nullptr, R));
321   EXPECT_FALSE(I->hasNoUnsignedWrap());
322 }
323 
324 // Check that SCEV expander does not use the nsw instruction
325 // for expansion.
326 TEST_F(ScalarEvolutionExpanderTest, SCEVExpanderNSW) {
327   /*
328    * Create the following code:
329    * func(i64 %a)
330    * entry:
331    *   br false, label %exit, label %body
332    * body:
333    *  %s1 = add i64 %a, -1
334    *  br label %exit
335    * exit:
336    *  %s = add nsw i64 %a, -1
337    *  ret %s
338    */
339 
340   // Create a module.
341   Module M("SCEVExpanderNSW", Context);
342 
343   Type *T_int64 = Type::getInt64Ty(Context);
344 
345   FunctionType *FTy =
346       FunctionType::get(Type::getVoidTy(Context), {T_int64}, false);
347   Function *F = Function::Create(FTy, Function::ExternalLinkage, "func", M);
348   Argument *Arg = &*F->arg_begin();
349   ConstantInt *C = ConstantInt::get(Context, APInt(64, -1));
350 
351   BasicBlock *Entry = BasicBlock::Create(Context, "entry", F);
352   BasicBlock *Body = BasicBlock::Create(Context, "body", F);
353   BasicBlock *Exit = BasicBlock::Create(Context, "exit", F);
354 
355   IRBuilder<> Builder(Entry);
356   ConstantInt *Cond = ConstantInt::get(Context, APInt(1, 0));
357   Builder.CreateCondBr(Cond, Exit, Body);
358 
359   Builder.SetInsertPoint(Body);
360   auto *S1 = cast<Instruction>(Builder.CreateAdd(Arg, C, "add"));
361   Builder.CreateBr(Exit);
362 
363   Builder.SetInsertPoint(Exit);
364   auto *S2 = cast<Instruction>(Builder.CreateAdd(Arg, C, "add"));
365   S2->setHasNoSignedWrap(true);
366   auto *R = cast<Instruction>(Builder.CreateRetVoid());
367 
368   ScalarEvolution SE = buildSE(*F);
369   const SCEV *S = SE.getSCEV(S1);
370   EXPECT_TRUE(isa<SCEVAddExpr>(S));
371   SCEVExpander Exp(SE, M.getDataLayout(), "expander");
372   auto *I = cast<Instruction>(Exp.expandCodeFor(S, nullptr, R));
373   EXPECT_FALSE(I->hasNoSignedWrap());
374 }
375 
376 // Check that SCEV does not save the SCEV -> V
377 // mapping of SCEV differ from V in NUW flag.
378 TEST_F(ScalarEvolutionExpanderTest, SCEVCacheNUW) {
379   /*
380    * Create the following code:
381    * func(i64 %a)
382    * entry:
383    *  %s1 = add i64 %a, -1
384    *  %s2 = add nuw i64 %a, -1
385    *  br label %exit
386    * exit:
387    *  ret %s
388    */
389 
390   // Create a module.
391   Module M("SCEVCacheNUW", Context);
392 
393   Type *T_int64 = Type::getInt64Ty(Context);
394 
395   FunctionType *FTy =
396       FunctionType::get(Type::getVoidTy(Context), {T_int64}, false);
397   Function *F = Function::Create(FTy, Function::ExternalLinkage, "func", M);
398   Argument *Arg = &*F->arg_begin();
399   ConstantInt *C = ConstantInt::get(Context, APInt(64, -1));
400 
401   BasicBlock *Entry = BasicBlock::Create(Context, "entry", F);
402   BasicBlock *Exit = BasicBlock::Create(Context, "exit", F);
403 
404   IRBuilder<> Builder(Entry);
405   auto *S1 = cast<Instruction>(Builder.CreateAdd(Arg, C, "add"));
406   auto *S2 = cast<Instruction>(Builder.CreateAdd(Arg, C, "add"));
407   S2->setHasNoUnsignedWrap(true);
408   Builder.CreateBr(Exit);
409 
410   Builder.SetInsertPoint(Exit);
411   auto *R = cast<Instruction>(Builder.CreateRetVoid());
412 
413   ScalarEvolution SE = buildSE(*F);
414   // Get S2 first to move it to cache.
415   const SCEV *SC2 = SE.getSCEV(S2);
416   EXPECT_TRUE(isa<SCEVAddExpr>(SC2));
417   // Now get S1.
418   const SCEV *SC1 = SE.getSCEV(S1);
419   EXPECT_TRUE(isa<SCEVAddExpr>(SC1));
420   // Expand for S1, it should use S1 not S2 in spite S2
421   // first in the cache.
422   SCEVExpander Exp(SE, M.getDataLayout(), "expander");
423   auto *I = cast<Instruction>(Exp.expandCodeFor(SC1, nullptr, R));
424   EXPECT_FALSE(I->hasNoUnsignedWrap());
425 }
426 
427 // Check that SCEV does not save the SCEV -> V
428 // mapping of SCEV differ from V in NSW flag.
429 TEST_F(ScalarEvolutionExpanderTest, SCEVCacheNSW) {
430   /*
431    * Create the following code:
432    * func(i64 %a)
433    * entry:
434    *  %s1 = add i64 %a, -1
435    *  %s2 = add nsw i64 %a, -1
436    *  br label %exit
437    * exit:
438    *  ret %s
439    */
440 
441   // Create a module.
442   Module M("SCEVCacheNUW", Context);
443 
444   Type *T_int64 = Type::getInt64Ty(Context);
445 
446   FunctionType *FTy =
447       FunctionType::get(Type::getVoidTy(Context), {T_int64}, false);
448   Function *F = Function::Create(FTy, Function::ExternalLinkage, "func", M);
449   Argument *Arg = &*F->arg_begin();
450   ConstantInt *C = ConstantInt::get(Context, APInt(64, -1));
451 
452   BasicBlock *Entry = BasicBlock::Create(Context, "entry", F);
453   BasicBlock *Exit = BasicBlock::Create(Context, "exit", F);
454 
455   IRBuilder<> Builder(Entry);
456   auto *S1 = cast<Instruction>(Builder.CreateAdd(Arg, C, "add"));
457   auto *S2 = cast<Instruction>(Builder.CreateAdd(Arg, C, "add"));
458   S2->setHasNoSignedWrap(true);
459   Builder.CreateBr(Exit);
460 
461   Builder.SetInsertPoint(Exit);
462   auto *R = cast<Instruction>(Builder.CreateRetVoid());
463 
464   ScalarEvolution SE = buildSE(*F);
465   // Get S2 first to move it to cache.
466   const SCEV *SC2 = SE.getSCEV(S2);
467   EXPECT_TRUE(isa<SCEVAddExpr>(SC2));
468   // Now get S1.
469   const SCEV *SC1 = SE.getSCEV(S1);
470   EXPECT_TRUE(isa<SCEVAddExpr>(SC1));
471   // Expand for S1, it should use S1 not S2 in spite S2
472   // first in the cache.
473   SCEVExpander Exp(SE, M.getDataLayout(), "expander");
474   auto *I = cast<Instruction>(Exp.expandCodeFor(SC1, nullptr, R));
475   EXPECT_FALSE(I->hasNoSignedWrap());
476 }
477 
// Exercise SCEVExpander's canonical-IV handling in three configurations:
// a loop with no canonical IV, one with a too-narrow canonical IV, and one
// whose canonical IV already matches the addrec's width.
TEST_F(ScalarEvolutionExpanderTest, SCEVExpandInsertCanonicalIV) {
  LLVMContext C;
  SMDiagnostic Err;

  // Expand the addrec produced by GetAddRec into a loop without a canonical IV.
  // SCEVExpander will insert one.
  auto TestNoCanonicalIV =
      [&](std::function<const SCEV *(ScalarEvolution & SE, Loop * L)>
              GetAddRec) {
        std::unique_ptr<Module> M = parseAssemblyString(
            "define i32 @test(i32 %limit) { "
            "entry: "
            "  br label %loop "
            "loop: "
            "  %i = phi i32 [ 1, %entry ], [ %i.inc, %loop ] "
            "  %i.inc = add nsw i32 %i, 1 "
            "  %cont = icmp slt i32 %i.inc, %limit "
            "  br i1 %cont, label %loop, label %exit "
            "exit: "
            "  ret i32 %i.inc "
            "}",
            Err, C);

        assert(M && "Could not parse module?");
        assert(!verifyModule(*M) && "Must have been well formed!");

        runWithSE(
            *M, "test", [&](Function &F, LoopInfo &LI, ScalarEvolution &SE) {
              auto &I = GetInstByName(F, "i");
              auto *Loop = LI.getLoopFor(I.getParent());
              // Precondition: %i starts at 1, so it is not a canonical IV.
              EXPECT_FALSE(Loop->getCanonicalInductionVariable());

              auto *AR = GetAddRec(SE, Loop);
              unsigned ExpectedCanonicalIVWidth =
                  SE.getTypeSizeInBits(AR->getType());

              SCEVExpander Exp(SE, M->getDataLayout(), "expander");
              auto *InsertAt = I.getNextNode();
              Exp.expandCodeFor(AR, nullptr, InsertAt);
              // Expansion must have created a canonical IV as wide as the
              // addrec.
              PHINode *CanonicalIV = Loop->getCanonicalInductionVariable();
              unsigned CanonicalIVBitWidth =
                  cast<IntegerType>(CanonicalIV->getType())->getBitWidth();
              EXPECT_EQ(CanonicalIVBitWidth, ExpectedCanonicalIVWidth);
            });
      };

  // Expand the addrec produced by GetAddRec into a loop with a canonical IV
  // which is narrower than addrec type.
  // SCEVExpander will insert a canonical IV of a wider type to expand the
  // addrec.
  auto TestNarrowCanonicalIV = [&](std::function<const SCEV *(
                                       ScalarEvolution & SE, Loop * L)>
                                       GetAddRec) {
    std::unique_ptr<Module> M = parseAssemblyString(
        "define i32 @test(i32 %limit) { "
        "entry: "
        "  br label %loop "
        "loop: "
        "  %i = phi i32 [ 1, %entry ], [ %i.inc, %loop ] "
        "  %canonical.iv = phi i8 [ 0, %entry ], [ %canonical.iv.inc, %loop ] "
        "  %i.inc = add nsw i32 %i, 1 "
        "  %canonical.iv.inc = add i8 %canonical.iv, 1 "
        "  %cont = icmp slt i32 %i.inc, %limit "
        "  br i1 %cont, label %loop, label %exit "
        "exit: "
        "  ret i32 %i.inc "
        "}",
        Err, C);

    assert(M && "Could not parse module?");
    assert(!verifyModule(*M) && "Must have been well formed!");

    runWithSE(*M, "test", [&](Function &F, LoopInfo &LI, ScalarEvolution &SE) {
      auto &I = GetInstByName(F, "i");

      auto *LoopHeaderBB = I.getParent();
      auto *Loop = LI.getLoopFor(LoopHeaderBB);
      // Precondition: the i8 %canonical.iv ({0,+,1}) is recognized as the
      // loop's canonical IV.
      PHINode *CanonicalIV = Loop->getCanonicalInductionVariable();
      EXPECT_EQ(CanonicalIV, &GetInstByName(F, "canonical.iv"));

      auto *AR = GetAddRec(SE, Loop);

      unsigned ExpectedCanonicalIVWidth = SE.getTypeSizeInBits(AR->getType());
      unsigned CanonicalIVBitWidth =
          cast<IntegerType>(CanonicalIV->getType())->getBitWidth();
      EXPECT_LT(CanonicalIVBitWidth, ExpectedCanonicalIVWidth);

      SCEVExpander Exp(SE, M->getDataLayout(), "expander");
      auto *InsertAt = I.getNextNode();
      Exp.expandCodeFor(AR, nullptr, InsertAt);

      // Loop over all of the PHI nodes, looking for the new canonical indvar.
      PHINode *NewCanonicalIV = nullptr;
      for (BasicBlock::iterator i = LoopHeaderBB->begin(); isa<PHINode>(i);
           ++i) {
        PHINode *PN = cast<PHINode>(i);
        if (PN == &I || PN == CanonicalIV)
          continue;
        // We expect that the only PHI added is the new canonical IV
        EXPECT_FALSE(NewCanonicalIV);
        NewCanonicalIV = PN;
      }

      // Check that NewCanonicalIV is a canonical IV, i.e {0,+,1}
      BasicBlock *Incoming = nullptr, *Backedge = nullptr;
      EXPECT_TRUE(Loop->getIncomingAndBackEdge(Incoming, Backedge));
      auto *Start = NewCanonicalIV->getIncomingValueForBlock(Incoming);
      EXPECT_TRUE(isa<ConstantInt>(Start));
      EXPECT_TRUE(dyn_cast<ConstantInt>(Start)->isZero());
      auto *Next = NewCanonicalIV->getIncomingValueForBlock(Backedge);
      EXPECT_TRUE(isa<BinaryOperator>(Next));
      auto *NextBinOp = dyn_cast<BinaryOperator>(Next);
      EXPECT_EQ(NextBinOp->getOpcode(), Instruction::Add);
      EXPECT_EQ(NextBinOp->getOperand(0), NewCanonicalIV);
      auto *Step = NextBinOp->getOperand(1);
      EXPECT_TRUE(isa<ConstantInt>(Step));
      EXPECT_TRUE(dyn_cast<ConstantInt>(Step)->isOne());

      // The inserted canonical IV must be as wide as the addrec.
      unsigned NewCanonicalIVBitWidth =
          cast<IntegerType>(NewCanonicalIV->getType())->getBitWidth();
      EXPECT_EQ(NewCanonicalIVBitWidth, ExpectedCanonicalIVWidth);
    });
  };

  // Expand the addrec produced by GetAddRec into a loop with a canonical IV
  // of addrec width.
  // To expand the addrec SCEVExpander should use the existing canonical IV.
  auto TestMatchingCanonicalIV =
      [&](std::function<const SCEV *(ScalarEvolution & SE, Loop * L)> GetAddRec,
          unsigned ARBitWidth) {
        auto ARBitWidthTypeStr = "i" + std::to_string(ARBitWidth);
        std::unique_ptr<Module> M = parseAssemblyString(
            "define i32 @test(i32 %limit) { "
            "entry: "
            "  br label %loop "
            "loop: "
            "  %i = phi i32 [ 1, %entry ], [ %i.inc, %loop ] "
            "  %canonical.iv = phi " +
                ARBitWidthTypeStr +
                " [ 0, %entry ], [ %canonical.iv.inc, %loop ] "
                "  %i.inc = add nsw i32 %i, 1 "
                "  %canonical.iv.inc = add " +
                ARBitWidthTypeStr +
                " %canonical.iv, 1 "
                "  %cont = icmp slt i32 %i.inc, %limit "
                "  br i1 %cont, label %loop, label %exit "
                "exit: "
                "  ret i32 %i.inc "
                "}",
            Err, C);

        assert(M && "Could not parse module?");
        assert(!verifyModule(*M) && "Must have been well formed!");

        runWithSE(
            *M, "test", [&](Function &F, LoopInfo &LI, ScalarEvolution &SE) {
              auto &I = GetInstByName(F, "i");
              auto &CanonicalIV = GetInstByName(F, "canonical.iv");

              auto *LoopHeaderBB = I.getParent();
              auto *Loop = LI.getLoopFor(LoopHeaderBB);
              EXPECT_EQ(&CanonicalIV, Loop->getCanonicalInductionVariable());
              unsigned CanonicalIVBitWidth =
                  cast<IntegerType>(CanonicalIV.getType())->getBitWidth();

              auto *AR = GetAddRec(SE, Loop);
              EXPECT_EQ(ARBitWidth, SE.getTypeSizeInBits(AR->getType()));
              EXPECT_EQ(CanonicalIVBitWidth, ARBitWidth);

              SCEVExpander Exp(SE, M->getDataLayout(), "expander");
              auto *InsertAt = I.getNextNode();
              Exp.expandCodeFor(AR, nullptr, InsertAt);

              // Loop over all of the PHI nodes, looking if a new canonical
              // indvar was introduced.
              PHINode *NewCanonicalIV = nullptr;
              for (BasicBlock::iterator i = LoopHeaderBB->begin();
                   isa<PHINode>(i); ++i) {
                PHINode *PN = cast<PHINode>(i);
                if (PN == &I || PN == &CanonicalIV)
                  continue;
                NewCanonicalIV = PN;
              }
              // No new PHI expected: the existing canonical IV was reused.
              EXPECT_FALSE(NewCanonicalIV);
            });
      };

  unsigned ARBitWidth = 16;
  Type *ARType = IntegerType::get(C, ARBitWidth);

  // Expand {5,+,1}
  auto GetAR2 = [&](ScalarEvolution &SE, Loop *L) -> const SCEV * {
    return SE.getAddRecExpr(SE.getConstant(APInt(ARBitWidth, 5)),
                            SE.getOne(ARType), L, SCEV::FlagAnyWrap);
  };
  TestNoCanonicalIV(GetAR2);
  TestNarrowCanonicalIV(GetAR2);
  TestMatchingCanonicalIV(GetAR2, ARBitWidth);
}
677 
// SCEV models "and %x, <sign-bit mask>" as a mul that carries nsw, but the
// expander lowers that mul to a shl; the shl must not inherit the nsw flag.
TEST_F(ScalarEvolutionExpanderTest, SCEVExpanderShlNSW) {

  auto checkOneCase = [this](std::string &&str) {
    LLVMContext C;
    SMDiagnostic Err;
    std::unique_ptr<Module> M = parseAssemblyString(str, Err, C);

    assert(M && "Could not parse module?");
    assert(!verifyModule(*M) && "Must have been well formed!");

    Function *F = M->getFunction("f");
    ASSERT_NE(F, nullptr) << "Could not find function 'f'";

    // Layout of the parsed IR: the load is the entry block's first
    // instruction, and the 'and' is its only user.
    BasicBlock &Entry = F->getEntryBlock();
    LoadInst *Load = cast<LoadInst>(&Entry.front());
    BinaryOperator *And = cast<BinaryOperator>(*Load->user_begin());

    ScalarEvolution SE = buildSE(*F);
    const SCEV *AndSCEV = SE.getSCEV(And);
    EXPECT_TRUE(isa<SCEVMulExpr>(AndSCEV));
    EXPECT_TRUE(cast<SCEVMulExpr>(AndSCEV)->hasNoSignedWrap());

    // Expanding the nsw mul as a shl must drop the nsw flag.
    SCEVExpander Exp(SE, M->getDataLayout(), "expander");
    auto *I = cast<Instruction>(Exp.expandCodeFor(AndSCEV, nullptr, And));
    EXPECT_EQ(I->getOpcode(), Instruction::Shl);
    EXPECT_FALSE(I->hasNoSignedWrap());
  };

  // Mask is the sign bit of i16 (0x8000).
  checkOneCase("define void @f(i16* %arrayidx) { "
               "  %1 = load i16, i16* %arrayidx "
               "  %2 = and i16 %1, -32768 "
               "  ret void "
               "} ");

  // Mask is the sign bit of i8 (0x80).
  checkOneCase("define void @f(i8* %arrayidx) { "
               "  %1 = load i8, i8* %arrayidx "
               "  %2 = and i8 %1, -128 "
               "  ret void "
               "} ");
}
718 
719 // Test expansion of nested addrecs in CanonicalMode.
720 // Expanding nested addrecs in canonical mode requiers a canonical IV of a
721 // type wider than the type of the addrec itself. Currently, SCEVExpander
722 // just falls back to literal mode for nested addrecs.
723 TEST_F(ScalarEvolutionExpanderTest, SCEVExpandNonAffineAddRec) {
724   LLVMContext C;
725   SMDiagnostic Err;
726 
727   // Expand the addrec produced by GetAddRec into a loop without a canonical IV.
728   auto TestNoCanonicalIV =
729       [&](std::function<const SCEVAddRecExpr *(ScalarEvolution & SE, Loop * L)>
730               GetAddRec) {
731         std::unique_ptr<Module> M = parseAssemblyString(
732             "define i32 @test(i32 %limit) { "
733             "entry: "
734             "  br label %loop "
735             "loop: "
736             "  %i = phi i32 [ 1, %entry ], [ %i.inc, %loop ] "
737             "  %i.inc = add nsw i32 %i, 1 "
738             "  %cont = icmp slt i32 %i.inc, %limit "
739             "  br i1 %cont, label %loop, label %exit "
740             "exit: "
741             "  ret i32 %i.inc "
742             "}",
743             Err, C);
744 
745         assert(M && "Could not parse module?");
746         assert(!verifyModule(*M) && "Must have been well formed!");
747 
748         runWithSE(*M, "test",
749                   [&](Function &F, LoopInfo &LI, ScalarEvolution &SE) {
750                     auto &I = GetInstByName(F, "i");
751                     auto *Loop = LI.getLoopFor(I.getParent());
752                     EXPECT_FALSE(Loop->getCanonicalInductionVariable());
753 
754                     auto *AR = GetAddRec(SE, Loop);
755                     EXPECT_FALSE(AR->isAffine());
756 
757                     SCEVExpander Exp(SE, M->getDataLayout(), "expander");
758                     auto *InsertAt = I.getNextNode();
759                     Value *V = Exp.expandCodeFor(AR, nullptr, InsertAt);
760                     const SCEV *ExpandedAR = SE.getSCEV(V);
761                     // Check that the expansion happened literally.
762                     EXPECT_EQ(AR, ExpandedAR);
763                   });
764       };
765 
  // Expand the addrec produced by GetAddRec into a loop with a canonical IV
  // which is narrower than addrec type.
  auto TestNarrowCanonicalIV = [&](std::function<const SCEVAddRecExpr *(
                                       ScalarEvolution & SE, Loop * L)>
                                       GetAddRec) {
    // Loop with an i32 working IV plus an i8 canonical IV ({0,+,1}), so the
    // canonical IV is narrower than the addrec the caller will expand.
    std::unique_ptr<Module> M = parseAssemblyString(
        "define i32 @test(i32 %limit) { "
        "entry: "
        "  br label %loop "
        "loop: "
        "  %i = phi i32 [ 1, %entry ], [ %i.inc, %loop ] "
        "  %canonical.iv = phi i8 [ 0, %entry ], [ %canonical.iv.inc, %loop ] "
        "  %i.inc = add nsw i32 %i, 1 "
        "  %canonical.iv.inc = add i8 %canonical.iv, 1 "
        "  %cont = icmp slt i32 %i.inc, %limit "
        "  br i1 %cont, label %loop, label %exit "
        "exit: "
        "  ret i32 %i.inc "
        "}",
        Err, C);

    assert(M && "Could not parse module?");
    assert(!verifyModule(*M) && "Must have been well formed!");

    runWithSE(*M, "test", [&](Function &F, LoopInfo &LI, ScalarEvolution &SE) {
      auto &I = GetInstByName(F, "i");

      auto *LoopHeaderBB = I.getParent();
      auto *Loop = LI.getLoopFor(LoopHeaderBB);
      // The i8 {0,+,1} phi must be recognized as the loop's canonical IV.
      PHINode *CanonicalIV = Loop->getCanonicalInductionVariable();
      EXPECT_EQ(CanonicalIV, &GetInstByName(F, "canonical.iv"));

      // The addrec under test must be non-affine (degree >= 2).
      auto *AR = GetAddRec(SE, Loop);
      EXPECT_FALSE(AR->isAffine());

      // Sanity check the premise of this test: the pre-existing canonical IV
      // really is narrower than the addrec's type, so the expander cannot
      // simply reuse it at that width.
      unsigned ExpectedCanonicalIVWidth = SE.getTypeSizeInBits(AR->getType());
      unsigned CanonicalIVBitWidth =
          cast<IntegerType>(CanonicalIV->getType())->getBitWidth();
      EXPECT_LT(CanonicalIVBitWidth, ExpectedCanonicalIVWidth);

      // Expand right after %i and re-derive the SCEV of the produced value.
      SCEVExpander Exp(SE, M->getDataLayout(), "expander");
      auto *InsertAt = I.getNextNode();
      Value *V = Exp.expandCodeFor(AR, nullptr, InsertAt);
      const SCEV *ExpandedAR = SE.getSCEV(V);
      // Check that the expansion happened literally.
      EXPECT_EQ(AR, ExpandedAR);
    });
  };
814 
  // Expand the addrec produced by GetAddRec into a loop with a canonical IV
  // of addrec width.
  auto TestMatchingCanonicalIV =
      [&](std::function<const SCEVAddRecExpr *(ScalarEvolution & SE, Loop * L)>
              GetAddRec,
          unsigned ARBitWidth) {
        // Build the loop with a canonical IV whose integer type
        // ("i<ARBitWidth>") matches the width of the addrec under test.
        auto ARBitWidthTypeStr = "i" + std::to_string(ARBitWidth);
        std::unique_ptr<Module> M = parseAssemblyString(
            "define i32 @test(i32 %limit) { "
            "entry: "
            "  br label %loop "
            "loop: "
            "  %i = phi i32 [ 1, %entry ], [ %i.inc, %loop ] "
            "  %canonical.iv = phi " +
                ARBitWidthTypeStr +
                " [ 0, %entry ], [ %canonical.iv.inc, %loop ] "
                "  %i.inc = add nsw i32 %i, 1 "
                "  %canonical.iv.inc = add " +
                ARBitWidthTypeStr +
                " %canonical.iv, 1 "
                "  %cont = icmp slt i32 %i.inc, %limit "
                "  br i1 %cont, label %loop, label %exit "
                "exit: "
                "  ret i32 %i.inc "
                "}",
            Err, C);

        assert(M && "Could not parse module?");
        assert(!verifyModule(*M) && "Must have been well formed!");

        runWithSE(
            *M, "test", [&](Function &F, LoopInfo &LI, ScalarEvolution &SE) {
              auto &I = GetInstByName(F, "i");
              auto &CanonicalIV = GetInstByName(F, "canonical.iv");

              auto *LoopHeaderBB = I.getParent();
              auto *Loop = LI.getLoopFor(LoopHeaderBB);
              // The {0,+,1} phi must be recognized as the canonical IV.
              EXPECT_EQ(&CanonicalIV, Loop->getCanonicalInductionVariable());
              unsigned CanonicalIVBitWidth =
                  cast<IntegerType>(CanonicalIV.getType())->getBitWidth();

              // The addrec under test must be non-affine, and its width must
              // match the canonical IV's width (the premise of this test).
              auto *AR = GetAddRec(SE, Loop);
              EXPECT_FALSE(AR->isAffine());
              EXPECT_EQ(ARBitWidth, SE.getTypeSizeInBits(AR->getType()));
              EXPECT_EQ(CanonicalIVBitWidth, ARBitWidth);

              // Expand right after %i and re-derive the SCEV of the result.
              SCEVExpander Exp(SE, M->getDataLayout(), "expander");
              auto *InsertAt = I.getNextNode();
              Value *V = Exp.expandCodeFor(AR, nullptr, InsertAt);
              const SCEV *ExpandedAR = SE.getSCEV(V);
              // Check that the expansion happened literally.
              EXPECT_EQ(AR, ExpandedAR);
            });
      };
869 
  // All addrecs below are built over a 16-bit integer type.
  unsigned ARBitWidth = 16;
  Type *ARType = IntegerType::get(C, ARBitWidth);

  // Expand {5,+,1,+,1}
  auto GetAR3 = [&](ScalarEvolution &SE, Loop *L) -> const SCEVAddRecExpr * {
    SmallVector<const SCEV *, 3> Ops = {SE.getConstant(APInt(ARBitWidth, 5)),
                                        SE.getOne(ARType), SE.getOne(ARType)};
    return cast<SCEVAddRecExpr>(SE.getAddRecExpr(Ops, L, SCEV::FlagAnyWrap));
  };
  // Exercise the three-operand (non-affine) addrec against each loop shape:
  // no canonical IV, a too-narrow canonical IV, and a width-matching one.
  TestNoCanonicalIV(GetAR3);
  TestNarrowCanonicalIV(GetAR3);
  TestMatchingCanonicalIV(GetAR3, ARBitWidth);

  // Expand {5,+,1,+,1,+,1}
  auto GetAR4 = [&](ScalarEvolution &SE, Loop *L) -> const SCEVAddRecExpr * {
    SmallVector<const SCEV *, 4> Ops = {SE.getConstant(APInt(ARBitWidth, 5)),
                                        SE.getOne(ARType), SE.getOne(ARType),
                                        SE.getOne(ARType)};
    return cast<SCEVAddRecExpr>(SE.getAddRecExpr(Ops, L, SCEV::FlagAnyWrap));
  };
  // Same three scenarios for the four-operand addrec.
  TestNoCanonicalIV(GetAR4);
  TestNarrowCanonicalIV(GetAR4);
  TestMatchingCanonicalIV(GetAR4, ARBitWidth);

  // Expand {5,+,1,+,1,+,1,+,1}
  auto GetAR5 = [&](ScalarEvolution &SE, Loop *L) -> const SCEVAddRecExpr * {
    SmallVector<const SCEV *, 5> Ops = {SE.getConstant(APInt(ARBitWidth, 5)),
                                        SE.getOne(ARType), SE.getOne(ARType),
                                        SE.getOne(ARType), SE.getOne(ARType)};
    return cast<SCEVAddRecExpr>(SE.getAddRecExpr(Ops, L, SCEV::FlagAnyWrap));
  };
  // Same three scenarios for the five-operand addrec.
  TestNoCanonicalIV(GetAR5);
  TestNarrowCanonicalIV(GetAR5);
  TestMatchingCanonicalIV(GetAR5, ARBitWidth);
}
905 
906 TEST_F(ScalarEvolutionExpanderTest, ExpandNonIntegralPtrWithNullBase) {
907   LLVMContext C;
908   SMDiagnostic Err;
909 
910   std::unique_ptr<Module> M =
911       parseAssemblyString("target datalayout = "
912                           "\"e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:"
913                           "128-n8:16:32:64-S128-ni:1-p2:32:8:8:32-ni:2\""
914                           "define ptr addrspace(1) @test(i64 %offset) { "
915                           "  %ptr = getelementptr inbounds float, ptr "
916                           "addrspace(1) null, i64 %offset"
917                           "  ret ptr addrspace(1) %ptr"
918                           "}",
919                           Err, C);
920 
921   assert(M && "Could not parse module?");
922   assert(!verifyModule(*M) && "Must have been well formed!");
923 
924   runWithSE(*M, "test", [&](Function &F, LoopInfo &LI, ScalarEvolution &SE) {
925     auto &I = GetInstByName(F, "ptr");
926     auto PtrPlus1 =
927         SE.getAddExpr(SE.getSCEV(&I), SE.getConstant(I.getType(), 1));
928     SCEVExpander Exp(SE, M->getDataLayout(), "expander");
929 
930     Value *V = Exp.expandCodeFor(PtrPlus1, I.getType(), &I);
931     I.replaceAllUsesWith(V);
932 
933     // Check that the expander created:
934     // define ptr addrspace(1) @test(i64 %off) {
935     //   %1 = shl i64 %offset, 2
936     //   %2 = add nuw nsw i64 %1, 1
937     //   %uglygep = getelementptr i8, ptr addrspace(1) null, i64 %2
938     //   %ptr = getelementptr inbounds float, ptr addrspace(1) null, i64 %off
939     //   ret ptr addrspace(1) %uglygep
940     // }
941 
942     Value *Offset = &*F.arg_begin();
943     auto *GEP = dyn_cast<GetElementPtrInst>(V);
944     EXPECT_TRUE(GEP);
945     EXPECT_TRUE(cast<Constant>(GEP->getPointerOperand())->isNullValue());
946     EXPECT_EQ(GEP->getNumOperands(), 2U);
947     EXPECT_TRUE(match(
948         GEP->getOperand(1),
949         m_Add(m_Shl(m_Specific(Offset), m_SpecificInt(2)), m_SpecificInt(1))));
950     EXPECT_EQ(cast<PointerType>(GEP->getPointerOperand()->getType())
951                   ->getAddressSpace(),
952               cast<PointerType>(I.getType())->getAddressSpace());
953     EXPECT_FALSE(verifyFunction(F, &errs()));
954   });
955 }
956 
957 TEST_F(ScalarEvolutionExpanderTest, GEPFlags) {
958   LLVMContext C;
959   SMDiagnostic Err;
960   StringRef ModStr = R"(
961   define void @f(ptr %p, i64 %x) {
962     %gep_inbounds = getelementptr inbounds i8, ptr %p, i64 %x
963     ret void
964   })";
965   std::unique_ptr<Module> M = parseAssemblyString(ModStr, Err, C);
966 
967   assert(M && "Could not parse module?");
968   assert(!verifyModule(*M) && "Must have been well formed!");
969 
970   Function *F = M->getFunction("f");
971   ASSERT_NE(F, nullptr) << "Could not find function 'f'";
972   BasicBlock &Entry = F->getEntryBlock();
973   auto *GEP = cast<GetElementPtrInst>(&Entry.front());
974 
975   ScalarEvolution SE = buildSE(*F);
976   const SCEV *Ptr = SE.getSCEV(F->getArg(0));
977   const SCEV *X = SE.getSCEV(F->getArg(1));
978   const SCEV *PtrX = SE.getAddExpr(Ptr, X);
979 
980   SCEVExpander Exp(SE, M->getDataLayout(), "expander");
981   auto *I = cast<Instruction>(
982       Exp.expandCodeFor(PtrX, nullptr, Entry.getTerminator()));
983   // Check that the GEP is reused, but the inbounds flag cleared. We don't
984   // know that the newly introduced use is inbounds.
985   EXPECT_EQ(I, GEP);
986   EXPECT_EQ(GEP->getNoWrapFlags(), GEPNoWrapFlags::none());
987 }
988 
989 } // end namespace llvm
990