xref: /llvm-project/llvm/unittests/Transforms/Utils/ScalarEvolutionExpanderTest.cpp (revision f75564ad4e4799465cf14b96f761e3fae13f6976)
1 //=== ScalarEvolutionExpanderTest.cpp - ScalarEvolutionExpander unit tests ===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
#include "llvm/Transforms/Utils/ScalarEvolutionExpander.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Analysis/AssumptionCache.h"
#include "llvm/Analysis/LoopInfo.h"
#include "llvm/Analysis/ScalarEvolutionExpressions.h"
#include "llvm/Analysis/TargetLibraryInfo.h"
#include "llvm/AsmParser/Parser.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/Dominators.h"
#include "llvm/IR/GlobalVariable.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/InstIterator.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/Verifier.h"
#include "gtest/gtest.h"

#include <memory>
25 
26 namespace llvm {
27 
28 // We use this fixture to ensure that we clean up ScalarEvolution before
29 // deleting the PassManager.
30 class ScalarEvolutionExpanderTest : public testing::Test {
31 protected:
32   LLVMContext Context;
33   Module M;
34   TargetLibraryInfoImpl TLII;
35   TargetLibraryInfo TLI;
36 
37   std::unique_ptr<AssumptionCache> AC;
38   std::unique_ptr<DominatorTree> DT;
39   std::unique_ptr<LoopInfo> LI;
40 
41   ScalarEvolutionExpanderTest() : M("", Context), TLII(), TLI(TLII) {}
42 
43   ScalarEvolution buildSE(Function &F) {
44     AC.reset(new AssumptionCache(F));
45     DT.reset(new DominatorTree(F));
46     LI.reset(new LoopInfo(*DT));
47     return ScalarEvolution(F, TLI, *AC, *DT, *LI);
48   }
49 
50   void runWithSE(
51       Module &M, StringRef FuncName,
52       function_ref<void(Function &F, LoopInfo &LI, ScalarEvolution &SE)> Test) {
53     auto *F = M.getFunction(FuncName);
54     ASSERT_NE(F, nullptr) << "Could not find " << FuncName;
55     ScalarEvolution SE = buildSE(*F);
56     Test(*F, *LI, SE);
57   }
58 };
59 
60 static Instruction &GetInstByName(Function &F, StringRef Name) {
61   for (auto &I : instructions(F))
62     if (I.getName() == Name)
63       return I;
64   llvm_unreachable("Could not find instructions!");
65 }
66 
67 TEST_F(ScalarEvolutionExpanderTest, ExpandPtrTypeSCEV) {
68   // It is to test the fix for PR30213. It exercises the branch in scev
69   // expansion when the value in ValueOffsetPair is a ptr and the offset
70   // is not divisible by the elem type size of value.
71   auto *I8Ty = Type::getInt8Ty(Context);
72   auto *I8PtrTy = Type::getInt8PtrTy(Context);
73   auto *I32Ty = Type::getInt32Ty(Context);
74   auto *I32PtrTy = Type::getInt32PtrTy(Context);
75   FunctionType *FTy =
76       FunctionType::get(Type::getVoidTy(Context), std::vector<Type *>(), false);
77   Function *F = Function::Create(FTy, Function::ExternalLinkage, "f", M);
78   BasicBlock *EntryBB = BasicBlock::Create(Context, "entry", F);
79   BasicBlock *LoopBB = BasicBlock::Create(Context, "loop", F);
80   BasicBlock *ExitBB = BasicBlock::Create(Context, "exit", F);
81   BranchInst::Create(LoopBB, EntryBB);
82   ReturnInst::Create(Context, nullptr, ExitBB);
83 
84   // loop:                            ; preds = %loop, %entry
85   //   %alloca = alloca i32
86   //   %gep0 = getelementptr i32, i32* %alloca, i32 1
87   //   %bitcast1 = bitcast i32* %gep0 to i8*
88   //   %gep1 = getelementptr i8, i8* %bitcast1, i32 1
89   //   %gep2 = getelementptr i8, i8* undef, i32 1
90   //   %cmp = icmp ult i8* undef, %bitcast1
91   //   %select = select i1 %cmp, i8* %gep1, i8* %gep2
92   //   %bitcast2 = bitcast i8* %select to i32*
93   //   br i1 undef, label %loop, label %exit
94 
95   const DataLayout &DL = F->getParent()->getDataLayout();
96   BranchInst *Br = BranchInst::Create(
97       LoopBB, ExitBB, UndefValue::get(Type::getInt1Ty(Context)), LoopBB);
98   AllocaInst *Alloca =
99       new AllocaInst(I32Ty, DL.getAllocaAddrSpace(), "alloca", Br);
100   ConstantInt *Ci32 = ConstantInt::get(Context, APInt(32, 1));
101   GetElementPtrInst *Gep0 =
102       GetElementPtrInst::Create(I32Ty, Alloca, Ci32, "gep0", Br);
103   CastInst *CastA =
104       CastInst::CreateBitOrPointerCast(Gep0, I8PtrTy, "bitcast1", Br);
105   GetElementPtrInst *Gep1 =
106       GetElementPtrInst::Create(I8Ty, CastA, Ci32, "gep1", Br);
107   GetElementPtrInst *Gep2 = GetElementPtrInst::Create(
108       I8Ty, UndefValue::get(I8PtrTy), Ci32, "gep2", Br);
109   CmpInst *Cmp = CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_ULT,
110                                  UndefValue::get(I8PtrTy), CastA, "cmp", Br);
111   SelectInst *Sel = SelectInst::Create(Cmp, Gep1, Gep2, "select", Br);
112   CastInst *CastB =
113       CastInst::CreateBitOrPointerCast(Sel, I32PtrTy, "bitcast2", Br);
114 
115   ScalarEvolution SE = buildSE(*F);
116   auto *S = SE.getSCEV(CastB);
117   SCEVExpander Exp(SE, M.getDataLayout(), "expander");
118   Value *V =
119       Exp.expandCodeFor(cast<SCEVAddExpr>(S)->getOperand(1), nullptr, Br);
120 
121   // Expect the expansion code contains:
122   //   %0 = bitcast i32* %bitcast2 to i8*
123   //   %uglygep = getelementptr i8, i8* %0, i64 -1
124   //   %1 = bitcast i8* %uglygep to i32*
125   EXPECT_TRUE(isa<BitCastInst>(V));
126   Instruction *Gep = cast<Instruction>(V)->getPrevNode();
127   EXPECT_TRUE(isa<GetElementPtrInst>(Gep));
128   EXPECT_TRUE(isa<ConstantInt>(Gep->getOperand(1)));
129   EXPECT_EQ(cast<ConstantInt>(Gep->getOperand(1))->getSExtValue(), -1);
130   EXPECT_TRUE(isa<BitCastInst>(Gep->getPrevNode()));
131 }
132 
133 // Make sure that SCEV doesn't introduce illegal ptrtoint/inttoptr instructions
134 TEST_F(ScalarEvolutionExpanderTest, SCEVZeroExtendExprNonIntegral) {
135   /*
136    * Create the following code:
137    * func(i64 addrspace(10)* %arg)
138    * top:
139    *  br label %L.ph
140    * L.ph:
141    *  br label %L
142    * L:
143    *  %phi = phi i64 [i64 0, %L.ph], [ %add, %L2 ]
144    *  %add = add i64 %phi2, 1
145    *  br i1 undef, label %post, label %L2
146    * post:
147    *  %gepbase = getelementptr i64 addrspace(10)* %arg, i64 1
148    *  #= %gep = getelementptr i64 addrspace(10)* %gepbase, i64 %add =#
149    *  ret void
150    *
151    * We will create the appropriate SCEV expression for %gep and expand it,
152    * then check that no inttoptr/ptrtoint instructions got inserted.
153    */
154 
155   // Create a module with non-integral pointers in it's datalayout
156   Module NIM("nonintegral", Context);
157   std::string DataLayout = M.getDataLayoutStr();
158   if (!DataLayout.empty())
159     DataLayout += "-";
160   DataLayout += "ni:10";
161   NIM.setDataLayout(DataLayout);
162 
163   Type *T_int1 = Type::getInt1Ty(Context);
164   Type *T_int64 = Type::getInt64Ty(Context);
165   Type *T_pint64 = T_int64->getPointerTo(10);
166 
167   FunctionType *FTy =
168       FunctionType::get(Type::getVoidTy(Context), {T_pint64}, false);
169   Function *F = Function::Create(FTy, Function::ExternalLinkage, "foo", NIM);
170 
171   Argument *Arg = &*F->arg_begin();
172 
173   BasicBlock *Top = BasicBlock::Create(Context, "top", F);
174   BasicBlock *LPh = BasicBlock::Create(Context, "L.ph", F);
175   BasicBlock *L = BasicBlock::Create(Context, "L", F);
176   BasicBlock *Post = BasicBlock::Create(Context, "post", F);
177 
178   IRBuilder<> Builder(Top);
179   Builder.CreateBr(LPh);
180 
181   Builder.SetInsertPoint(LPh);
182   Builder.CreateBr(L);
183 
184   Builder.SetInsertPoint(L);
185   PHINode *Phi = Builder.CreatePHI(T_int64, 2);
186   Value *Add = Builder.CreateAdd(Phi, ConstantInt::get(T_int64, 1), "add");
187   Builder.CreateCondBr(UndefValue::get(T_int1), L, Post);
188   Phi->addIncoming(ConstantInt::get(T_int64, 0), LPh);
189   Phi->addIncoming(Add, L);
190 
191   Builder.SetInsertPoint(Post);
192   Value *GepBase =
193       Builder.CreateGEP(T_int64, Arg, ConstantInt::get(T_int64, 1));
194   Instruction *Ret = Builder.CreateRetVoid();
195 
196   ScalarEvolution SE = buildSE(*F);
197   auto *AddRec =
198       SE.getAddRecExpr(SE.getUnknown(GepBase), SE.getConstant(T_int64, 1),
199                        LI->getLoopFor(L), SCEV::FlagNUW);
200 
201   SCEVExpander Exp(SE, NIM.getDataLayout(), "expander");
202   Exp.disableCanonicalMode();
203   Exp.expandCodeFor(AddRec, T_pint64, Ret);
204 
205   // Make sure none of the instructions inserted were inttoptr/ptrtoint.
206   // The verifier will check this.
207   EXPECT_FALSE(verifyFunction(*F, &errs()));
208 }
209 
210 // Check that we can correctly identify the points at which the SCEV of the
211 // AddRec can be expanded.
212 TEST_F(ScalarEvolutionExpanderTest, SCEVExpanderIsSafeToExpandAt) {
213   /*
214    * Create the following code:
215    * func(i64 addrspace(10)* %arg)
216    * top:
217    *  br label %L.ph
218    * L.ph:
219    *  br label %L
220    * L:
221    *  %phi = phi i64 [i64 0, %L.ph], [ %add, %L2 ]
222    *  %add = add i64 %phi2, 1
223    *  %cond = icmp slt i64 %add, 1000; then becomes 2000.
224    *  br i1 %cond, label %post, label %L2
225    * post:
226    *  ret void
227    *
228    */
229 
230   // Create a module with non-integral pointers in it's datalayout
231   Module NIM("nonintegral", Context);
232   std::string DataLayout = M.getDataLayoutStr();
233   if (!DataLayout.empty())
234     DataLayout += "-";
235   DataLayout += "ni:10";
236   NIM.setDataLayout(DataLayout);
237 
238   Type *T_int64 = Type::getInt64Ty(Context);
239   Type *T_pint64 = T_int64->getPointerTo(10);
240 
241   FunctionType *FTy =
242       FunctionType::get(Type::getVoidTy(Context), {T_pint64}, false);
243   Function *F = Function::Create(FTy, Function::ExternalLinkage, "foo", NIM);
244 
245   BasicBlock *Top = BasicBlock::Create(Context, "top", F);
246   BasicBlock *LPh = BasicBlock::Create(Context, "L.ph", F);
247   BasicBlock *L = BasicBlock::Create(Context, "L", F);
248   BasicBlock *Post = BasicBlock::Create(Context, "post", F);
249 
250   IRBuilder<> Builder(Top);
251   Builder.CreateBr(LPh);
252 
253   Builder.SetInsertPoint(LPh);
254   Builder.CreateBr(L);
255 
256   Builder.SetInsertPoint(L);
257   PHINode *Phi = Builder.CreatePHI(T_int64, 2);
258   auto *Add = cast<Instruction>(
259       Builder.CreateAdd(Phi, ConstantInt::get(T_int64, 1), "add"));
260   auto *Limit = ConstantInt::get(T_int64, 1000);
261   auto *Cond = cast<Instruction>(
262       Builder.CreateICmp(ICmpInst::ICMP_SLT, Add, Limit, "cond"));
263   Builder.CreateCondBr(Cond, L, Post);
264   Phi->addIncoming(ConstantInt::get(T_int64, 0), LPh);
265   Phi->addIncoming(Add, L);
266 
267   Builder.SetInsertPoint(Post);
268   Instruction *Ret = Builder.CreateRetVoid();
269 
270   ScalarEvolution SE = buildSE(*F);
271   const SCEV *S = SE.getSCEV(Phi);
272   EXPECT_TRUE(isa<SCEVAddRecExpr>(S));
273   const SCEVAddRecExpr *AR = cast<SCEVAddRecExpr>(S);
274   EXPECT_TRUE(AR->isAffine());
275   EXPECT_FALSE(isSafeToExpandAt(AR, Top->getTerminator(), SE));
276   EXPECT_FALSE(isSafeToExpandAt(AR, LPh->getTerminator(), SE));
277   EXPECT_TRUE(isSafeToExpandAt(AR, L->getTerminator(), SE));
278   EXPECT_TRUE(isSafeToExpandAt(AR, Post->getTerminator(), SE));
279 
280   EXPECT_TRUE(LI->getLoopFor(L)->isLCSSAForm(*DT));
281   SCEVExpander Exp(SE, M.getDataLayout(), "expander");
282   Exp.expandCodeFor(SE.getSCEV(Add), nullptr, Ret);
283   EXPECT_TRUE(LI->getLoopFor(L)->isLCSSAForm(*DT));
284 }
285 
286 // Check that SCEV expander does not use the nuw instruction
287 // for expansion.
288 TEST_F(ScalarEvolutionExpanderTest, SCEVExpanderNUW) {
289   /*
290    * Create the following code:
291    * func(i64 %a)
292    * entry:
293    *   br false, label %exit, label %body
294    * body:
295    *  %s1 = add i64 %a, -1
296    *  br label %exit
297    * exit:
298    *  %s = add nuw i64 %a, -1
299    *  ret %s
300    */
301 
302   // Create a module.
303   Module M("SCEVExpanderNUW", Context);
304 
305   Type *T_int64 = Type::getInt64Ty(Context);
306 
307   FunctionType *FTy =
308       FunctionType::get(Type::getVoidTy(Context), {T_int64}, false);
309   Function *F = Function::Create(FTy, Function::ExternalLinkage, "func", M);
310   Argument *Arg = &*F->arg_begin();
311   ConstantInt *C = ConstantInt::get(Context, APInt(64, -1));
312 
313   BasicBlock *Entry = BasicBlock::Create(Context, "entry", F);
314   BasicBlock *Body = BasicBlock::Create(Context, "body", F);
315   BasicBlock *Exit = BasicBlock::Create(Context, "exit", F);
316 
317   IRBuilder<> Builder(Entry);
318   ConstantInt *Cond = ConstantInt::get(Context, APInt(1, 0));
319   Builder.CreateCondBr(Cond, Exit, Body);
320 
321   Builder.SetInsertPoint(Body);
322   auto *S1 = cast<Instruction>(Builder.CreateAdd(Arg, C, "add"));
323   Builder.CreateBr(Exit);
324 
325   Builder.SetInsertPoint(Exit);
326   auto *S2 = cast<Instruction>(Builder.CreateAdd(Arg, C, "add"));
327   S2->setHasNoUnsignedWrap(true);
328   auto *R = cast<Instruction>(Builder.CreateRetVoid());
329 
330   ScalarEvolution SE = buildSE(*F);
331   const SCEV *S = SE.getSCEV(S1);
332   EXPECT_TRUE(isa<SCEVAddExpr>(S));
333   SCEVExpander Exp(SE, M.getDataLayout(), "expander");
334   auto *I = cast<Instruction>(Exp.expandCodeFor(S, nullptr, R));
335   EXPECT_FALSE(I->hasNoUnsignedWrap());
336 }
337 
338 // Check that SCEV expander does not use the nsw instruction
339 // for expansion.
340 TEST_F(ScalarEvolutionExpanderTest, SCEVExpanderNSW) {
341   /*
342    * Create the following code:
343    * func(i64 %a)
344    * entry:
345    *   br false, label %exit, label %body
346    * body:
347    *  %s1 = add i64 %a, -1
348    *  br label %exit
349    * exit:
350    *  %s = add nsw i64 %a, -1
351    *  ret %s
352    */
353 
354   // Create a module.
355   Module M("SCEVExpanderNSW", Context);
356 
357   Type *T_int64 = Type::getInt64Ty(Context);
358 
359   FunctionType *FTy =
360       FunctionType::get(Type::getVoidTy(Context), {T_int64}, false);
361   Function *F = Function::Create(FTy, Function::ExternalLinkage, "func", M);
362   Argument *Arg = &*F->arg_begin();
363   ConstantInt *C = ConstantInt::get(Context, APInt(64, -1));
364 
365   BasicBlock *Entry = BasicBlock::Create(Context, "entry", F);
366   BasicBlock *Body = BasicBlock::Create(Context, "body", F);
367   BasicBlock *Exit = BasicBlock::Create(Context, "exit", F);
368 
369   IRBuilder<> Builder(Entry);
370   ConstantInt *Cond = ConstantInt::get(Context, APInt(1, 0));
371   Builder.CreateCondBr(Cond, Exit, Body);
372 
373   Builder.SetInsertPoint(Body);
374   auto *S1 = cast<Instruction>(Builder.CreateAdd(Arg, C, "add"));
375   Builder.CreateBr(Exit);
376 
377   Builder.SetInsertPoint(Exit);
378   auto *S2 = cast<Instruction>(Builder.CreateAdd(Arg, C, "add"));
379   S2->setHasNoSignedWrap(true);
380   auto *R = cast<Instruction>(Builder.CreateRetVoid());
381 
382   ScalarEvolution SE = buildSE(*F);
383   const SCEV *S = SE.getSCEV(S1);
384   EXPECT_TRUE(isa<SCEVAddExpr>(S));
385   SCEVExpander Exp(SE, M.getDataLayout(), "expander");
386   auto *I = cast<Instruction>(Exp.expandCodeFor(S, nullptr, R));
387   EXPECT_FALSE(I->hasNoSignedWrap());
388 }
389 
390 // Check that SCEV does not save the SCEV -> V
391 // mapping of SCEV differ from V in NUW flag.
392 TEST_F(ScalarEvolutionExpanderTest, SCEVCacheNUW) {
393   /*
394    * Create the following code:
395    * func(i64 %a)
396    * entry:
397    *  %s1 = add i64 %a, -1
398    *  %s2 = add nuw i64 %a, -1
399    *  br label %exit
400    * exit:
401    *  ret %s
402    */
403 
404   // Create a module.
405   Module M("SCEVCacheNUW", Context);
406 
407   Type *T_int64 = Type::getInt64Ty(Context);
408 
409   FunctionType *FTy =
410       FunctionType::get(Type::getVoidTy(Context), {T_int64}, false);
411   Function *F = Function::Create(FTy, Function::ExternalLinkage, "func", M);
412   Argument *Arg = &*F->arg_begin();
413   ConstantInt *C = ConstantInt::get(Context, APInt(64, -1));
414 
415   BasicBlock *Entry = BasicBlock::Create(Context, "entry", F);
416   BasicBlock *Exit = BasicBlock::Create(Context, "exit", F);
417 
418   IRBuilder<> Builder(Entry);
419   auto *S1 = cast<Instruction>(Builder.CreateAdd(Arg, C, "add"));
420   auto *S2 = cast<Instruction>(Builder.CreateAdd(Arg, C, "add"));
421   S2->setHasNoUnsignedWrap(true);
422   Builder.CreateBr(Exit);
423 
424   Builder.SetInsertPoint(Exit);
425   auto *R = cast<Instruction>(Builder.CreateRetVoid());
426 
427   ScalarEvolution SE = buildSE(*F);
428   // Get S2 first to move it to cache.
429   const SCEV *SC2 = SE.getSCEV(S2);
430   EXPECT_TRUE(isa<SCEVAddExpr>(SC2));
431   // Now get S1.
432   const SCEV *SC1 = SE.getSCEV(S1);
433   EXPECT_TRUE(isa<SCEVAddExpr>(SC1));
434   // Expand for S1, it should use S1 not S2 in spite S2
435   // first in the cache.
436   SCEVExpander Exp(SE, M.getDataLayout(), "expander");
437   auto *I = cast<Instruction>(Exp.expandCodeFor(SC1, nullptr, R));
438   EXPECT_FALSE(I->hasNoUnsignedWrap());
439 }
440 
441 // Check that SCEV does not save the SCEV -> V
442 // mapping of SCEV differ from V in NSW flag.
443 TEST_F(ScalarEvolutionExpanderTest, SCEVCacheNSW) {
444   /*
445    * Create the following code:
446    * func(i64 %a)
447    * entry:
448    *  %s1 = add i64 %a, -1
449    *  %s2 = add nsw i64 %a, -1
450    *  br label %exit
451    * exit:
452    *  ret %s
453    */
454 
455   // Create a module.
456   Module M("SCEVCacheNUW", Context);
457 
458   Type *T_int64 = Type::getInt64Ty(Context);
459 
460   FunctionType *FTy =
461       FunctionType::get(Type::getVoidTy(Context), {T_int64}, false);
462   Function *F = Function::Create(FTy, Function::ExternalLinkage, "func", M);
463   Argument *Arg = &*F->arg_begin();
464   ConstantInt *C = ConstantInt::get(Context, APInt(64, -1));
465 
466   BasicBlock *Entry = BasicBlock::Create(Context, "entry", F);
467   BasicBlock *Exit = BasicBlock::Create(Context, "exit", F);
468 
469   IRBuilder<> Builder(Entry);
470   auto *S1 = cast<Instruction>(Builder.CreateAdd(Arg, C, "add"));
471   auto *S2 = cast<Instruction>(Builder.CreateAdd(Arg, C, "add"));
472   S2->setHasNoSignedWrap(true);
473   Builder.CreateBr(Exit);
474 
475   Builder.SetInsertPoint(Exit);
476   auto *R = cast<Instruction>(Builder.CreateRetVoid());
477 
478   ScalarEvolution SE = buildSE(*F);
479   // Get S2 first to move it to cache.
480   const SCEV *SC2 = SE.getSCEV(S2);
481   EXPECT_TRUE(isa<SCEVAddExpr>(SC2));
482   // Now get S1.
483   const SCEV *SC1 = SE.getSCEV(S1);
484   EXPECT_TRUE(isa<SCEVAddExpr>(SC1));
485   // Expand for S1, it should use S1 not S2 in spite S2
486   // first in the cache.
487   SCEVExpander Exp(SE, M.getDataLayout(), "expander");
488   auto *I = cast<Instruction>(Exp.expandCodeFor(SC1, nullptr, R));
489   EXPECT_FALSE(I->hasNoSignedWrap());
490 }
491 
492 TEST_F(ScalarEvolutionExpanderTest, SCEVExpandInsertCanonicalIV) {
493   LLVMContext C;
494   SMDiagnostic Err;
495 
496   // Expand the addrec produced by GetAddRec into a loop without a canonical IV.
497   // SCEVExpander will insert one.
498   auto TestNoCanonicalIV =
499       [&](std::function<const SCEV *(ScalarEvolution & SE, Loop * L)>
500               GetAddRec) {
501         std::unique_ptr<Module> M = parseAssemblyString(
502             "define i32 @test(i32 %limit) { "
503             "entry: "
504             "  br label %loop "
505             "loop: "
506             "  %i = phi i32 [ 1, %entry ], [ %i.inc, %loop ] "
507             "  %i.inc = add nsw i32 %i, 1 "
508             "  %cont = icmp slt i32 %i.inc, %limit "
509             "  br i1 %cont, label %loop, label %exit "
510             "exit: "
511             "  ret i32 %i.inc "
512             "}",
513             Err, C);
514 
515         assert(M && "Could not parse module?");
516         assert(!verifyModule(*M) && "Must have been well formed!");
517 
518         runWithSE(
519             *M, "test", [&](Function &F, LoopInfo &LI, ScalarEvolution &SE) {
520               auto &I = GetInstByName(F, "i");
521               auto *Loop = LI.getLoopFor(I.getParent());
522               EXPECT_FALSE(Loop->getCanonicalInductionVariable());
523 
524               auto *AR = GetAddRec(SE, Loop);
525               unsigned ExpectedCanonicalIVWidth =
526                   SE.getTypeSizeInBits(AR->getType());
527 
528               SCEVExpander Exp(SE, M->getDataLayout(), "expander");
529               auto *InsertAt = I.getNextNode();
530               Exp.expandCodeFor(AR, nullptr, InsertAt);
531               PHINode *CanonicalIV = Loop->getCanonicalInductionVariable();
532               unsigned CanonicalIVBitWidth =
533                   cast<IntegerType>(CanonicalIV->getType())->getBitWidth();
534               EXPECT_EQ(CanonicalIVBitWidth, ExpectedCanonicalIVWidth);
535             });
536       };
537 
538   // Expand the addrec produced by GetAddRec into a loop with a canonical IV
539   // which is narrower than addrec type.
540   // SCEVExpander will insert a canonical IV of a wider type to expand the
541   // addrec.
542   auto TestNarrowCanonicalIV = [&](std::function<const SCEV *(
543                                        ScalarEvolution & SE, Loop * L)>
544                                        GetAddRec) {
545     std::unique_ptr<Module> M = parseAssemblyString(
546         "define i32 @test(i32 %limit) { "
547         "entry: "
548         "  br label %loop "
549         "loop: "
550         "  %i = phi i32 [ 1, %entry ], [ %i.inc, %loop ] "
551         "  %canonical.iv = phi i8 [ 0, %entry ], [ %canonical.iv.inc, %loop ] "
552         "  %i.inc = add nsw i32 %i, 1 "
553         "  %canonical.iv.inc = add i8 %canonical.iv, 1 "
554         "  %cont = icmp slt i32 %i.inc, %limit "
555         "  br i1 %cont, label %loop, label %exit "
556         "exit: "
557         "  ret i32 %i.inc "
558         "}",
559         Err, C);
560 
561     assert(M && "Could not parse module?");
562     assert(!verifyModule(*M) && "Must have been well formed!");
563 
564     runWithSE(*M, "test", [&](Function &F, LoopInfo &LI, ScalarEvolution &SE) {
565       auto &I = GetInstByName(F, "i");
566 
567       auto *LoopHeaderBB = I.getParent();
568       auto *Loop = LI.getLoopFor(LoopHeaderBB);
569       PHINode *CanonicalIV = Loop->getCanonicalInductionVariable();
570       EXPECT_EQ(CanonicalIV, &GetInstByName(F, "canonical.iv"));
571 
572       auto *AR = GetAddRec(SE, Loop);
573 
574       unsigned ExpectedCanonicalIVWidth = SE.getTypeSizeInBits(AR->getType());
575       unsigned CanonicalIVBitWidth =
576           cast<IntegerType>(CanonicalIV->getType())->getBitWidth();
577       EXPECT_LT(CanonicalIVBitWidth, ExpectedCanonicalIVWidth);
578 
579       SCEVExpander Exp(SE, M->getDataLayout(), "expander");
580       auto *InsertAt = I.getNextNode();
581       Exp.expandCodeFor(AR, nullptr, InsertAt);
582 
583       // Loop over all of the PHI nodes, looking for the new canonical indvar.
584       PHINode *NewCanonicalIV = nullptr;
585       for (BasicBlock::iterator i = LoopHeaderBB->begin(); isa<PHINode>(i);
586            ++i) {
587         PHINode *PN = cast<PHINode>(i);
588         if (PN == &I || PN == CanonicalIV)
589           continue;
590         // We expect that the only PHI added is the new canonical IV
591         EXPECT_FALSE(NewCanonicalIV);
592         NewCanonicalIV = PN;
593       }
594 
595       // Check that NewCanonicalIV is a canonical IV, i.e {0,+,1}
596       BasicBlock *Incoming = nullptr, *Backedge = nullptr;
597       EXPECT_TRUE(Loop->getIncomingAndBackEdge(Incoming, Backedge));
598       auto *Start = NewCanonicalIV->getIncomingValueForBlock(Incoming);
599       EXPECT_TRUE(isa<ConstantInt>(Start));
600       EXPECT_TRUE(dyn_cast<ConstantInt>(Start)->isZero());
601       auto *Next = NewCanonicalIV->getIncomingValueForBlock(Backedge);
602       EXPECT_TRUE(isa<BinaryOperator>(Next));
603       auto *NextBinOp = dyn_cast<BinaryOperator>(Next);
604       EXPECT_EQ(NextBinOp->getOpcode(), Instruction::Add);
605       EXPECT_EQ(NextBinOp->getOperand(0), NewCanonicalIV);
606       auto *Step = NextBinOp->getOperand(1);
607       EXPECT_TRUE(isa<ConstantInt>(Step));
608       EXPECT_TRUE(dyn_cast<ConstantInt>(Step)->isOne());
609 
610       unsigned NewCanonicalIVBitWidth =
611           cast<IntegerType>(NewCanonicalIV->getType())->getBitWidth();
612       EXPECT_EQ(NewCanonicalIVBitWidth, ExpectedCanonicalIVWidth);
613     });
614   };
615 
616   // Expand the addrec produced by GetAddRec into a loop with a canonical IV
617   // of addrec width.
618   // To expand the addrec SCEVExpander should use the existing canonical IV.
619   auto TestMatchingCanonicalIV =
620       [&](std::function<const SCEV *(ScalarEvolution & SE, Loop * L)> GetAddRec,
621           unsigned ARBitWidth) {
622         auto ARBitWidthTypeStr = "i" + std::to_string(ARBitWidth);
623         std::unique_ptr<Module> M = parseAssemblyString(
624             "define i32 @test(i32 %limit) { "
625             "entry: "
626             "  br label %loop "
627             "loop: "
628             "  %i = phi i32 [ 1, %entry ], [ %i.inc, %loop ] "
629             "  %canonical.iv = phi " +
630                 ARBitWidthTypeStr +
631                 " [ 0, %entry ], [ %canonical.iv.inc, %loop ] "
632                 "  %i.inc = add nsw i32 %i, 1 "
633                 "  %canonical.iv.inc = add " +
634                 ARBitWidthTypeStr +
635                 " %canonical.iv, 1 "
636                 "  %cont = icmp slt i32 %i.inc, %limit "
637                 "  br i1 %cont, label %loop, label %exit "
638                 "exit: "
639                 "  ret i32 %i.inc "
640                 "}",
641             Err, C);
642 
643         assert(M && "Could not parse module?");
644         assert(!verifyModule(*M) && "Must have been well formed!");
645 
646         runWithSE(
647             *M, "test", [&](Function &F, LoopInfo &LI, ScalarEvolution &SE) {
648               auto &I = GetInstByName(F, "i");
649               auto &CanonicalIV = GetInstByName(F, "canonical.iv");
650 
651               auto *LoopHeaderBB = I.getParent();
652               auto *Loop = LI.getLoopFor(LoopHeaderBB);
653               EXPECT_EQ(&CanonicalIV, Loop->getCanonicalInductionVariable());
654               unsigned CanonicalIVBitWidth =
655                   cast<IntegerType>(CanonicalIV.getType())->getBitWidth();
656 
657               auto *AR = GetAddRec(SE, Loop);
658               EXPECT_EQ(ARBitWidth, SE.getTypeSizeInBits(AR->getType()));
659               EXPECT_EQ(CanonicalIVBitWidth, ARBitWidth);
660 
661               SCEVExpander Exp(SE, M->getDataLayout(), "expander");
662               auto *InsertAt = I.getNextNode();
663               Exp.expandCodeFor(AR, nullptr, InsertAt);
664 
665               // Loop over all of the PHI nodes, looking if a new canonical
666               // indvar was introduced.
667               PHINode *NewCanonicalIV = nullptr;
668               for (BasicBlock::iterator i = LoopHeaderBB->begin();
669                    isa<PHINode>(i); ++i) {
670                 PHINode *PN = cast<PHINode>(i);
671                 if (PN == &I || PN == &CanonicalIV)
672                   continue;
673                 NewCanonicalIV = PN;
674               }
675               EXPECT_FALSE(NewCanonicalIV);
676             });
677       };
678 
679   unsigned ARBitWidth = 16;
680   Type *ARType = IntegerType::get(C, ARBitWidth);
681 
682   // Expand {5,+,1}
683   auto GetAR2 = [&](ScalarEvolution &SE, Loop *L) -> const SCEV * {
684     return SE.getAddRecExpr(SE.getConstant(APInt(ARBitWidth, 5)),
685                             SE.getOne(ARType), L, SCEV::FlagAnyWrap);
686   };
687   TestNoCanonicalIV(GetAR2);
688   TestNarrowCanonicalIV(GetAR2);
689   TestMatchingCanonicalIV(GetAR2, ARBitWidth);
690 }
691 
692 TEST_F(ScalarEvolutionExpanderTest, SCEVExpanderShlNSW) {
693 
694   auto checkOneCase = [this](std::string &&str) {
695     LLVMContext C;
696     SMDiagnostic Err;
697     std::unique_ptr<Module> M = parseAssemblyString(str, Err, C);
698 
699     assert(M && "Could not parse module?");
700     assert(!verifyModule(*M) && "Must have been well formed!");
701 
702     Function *F = M->getFunction("f");
703     ASSERT_NE(F, nullptr) << "Could not find function 'f'";
704 
705     BasicBlock &Entry = F->getEntryBlock();
706     LoadInst *Load = cast<LoadInst>(&Entry.front());
707     BinaryOperator *And = cast<BinaryOperator>(*Load->user_begin());
708 
709     ScalarEvolution SE = buildSE(*F);
710     const SCEV *AndSCEV = SE.getSCEV(And);
711     EXPECT_TRUE(isa<SCEVMulExpr>(AndSCEV));
712     EXPECT_TRUE(cast<SCEVMulExpr>(AndSCEV)->hasNoSignedWrap());
713 
714     SCEVExpander Exp(SE, M->getDataLayout(), "expander");
715     auto *I = cast<Instruction>(Exp.expandCodeFor(AndSCEV, nullptr, And));
716     EXPECT_EQ(I->getOpcode(), Instruction::Shl);
717     EXPECT_FALSE(I->hasNoSignedWrap());
718   };
719 
720   checkOneCase("define void @f(i16* %arrayidx) { "
721                "  %1 = load i16, i16* %arrayidx "
722                "  %2 = and i16 %1, -32768 "
723                "  ret void "
724                "} ");
725 
726   checkOneCase("define void @f(i8* %arrayidx) { "
727                "  %1 = load i8, i8* %arrayidx "
728                "  %2 = and i8 %1, -128 "
729                "  ret void "
730                "} ");
731 }
732 
733 // Test expansion of nested addrecs in CanonicalMode.
734 // Expanding nested addrecs in canonical mode requiers a canonical IV of a
735 // type wider than the type of the addrec itself. Currently, SCEVExpander
736 // just falls back to literal mode for nested addrecs.
737 TEST_F(ScalarEvolutionExpanderTest, SCEVExpandNonAffineAddRec) {
738   LLVMContext C;
739   SMDiagnostic Err;
740 
741   // Expand the addrec produced by GetAddRec into a loop without a canonical IV.
742   auto TestNoCanonicalIV =
743       [&](std::function<const SCEVAddRecExpr *(ScalarEvolution & SE, Loop * L)>
744               GetAddRec) {
745         std::unique_ptr<Module> M = parseAssemblyString(
746             "define i32 @test(i32 %limit) { "
747             "entry: "
748             "  br label %loop "
749             "loop: "
750             "  %i = phi i32 [ 1, %entry ], [ %i.inc, %loop ] "
751             "  %i.inc = add nsw i32 %i, 1 "
752             "  %cont = icmp slt i32 %i.inc, %limit "
753             "  br i1 %cont, label %loop, label %exit "
754             "exit: "
755             "  ret i32 %i.inc "
756             "}",
757             Err, C);
758 
759         assert(M && "Could not parse module?");
760         assert(!verifyModule(*M) && "Must have been well formed!");
761 
762         runWithSE(*M, "test",
763                   [&](Function &F, LoopInfo &LI, ScalarEvolution &SE) {
764                     auto &I = GetInstByName(F, "i");
765                     auto *Loop = LI.getLoopFor(I.getParent());
766                     EXPECT_FALSE(Loop->getCanonicalInductionVariable());
767 
768                     auto *AR = GetAddRec(SE, Loop);
769                     EXPECT_FALSE(AR->isAffine());
770 
771                     SCEVExpander Exp(SE, M->getDataLayout(), "expander");
772                     auto *InsertAt = I.getNextNode();
773                     Value *V = Exp.expandCodeFor(AR, nullptr, InsertAt);
774                     auto *ExpandedAR = SE.getSCEV(V);
775                     // Check that the expansion happened literally.
776                     EXPECT_EQ(AR, ExpandedAR);
777                   });
778       };
779 
780   // Expand the addrec produced by GetAddRec into a loop with a canonical IV
781   // which is narrower than addrec type.
782   auto TestNarrowCanonicalIV = [&](std::function<const SCEVAddRecExpr *(
783                                        ScalarEvolution & SE, Loop * L)>
784                                        GetAddRec) {
785     std::unique_ptr<Module> M = parseAssemblyString(
786         "define i32 @test(i32 %limit) { "
787         "entry: "
788         "  br label %loop "
789         "loop: "
790         "  %i = phi i32 [ 1, %entry ], [ %i.inc, %loop ] "
791         "  %canonical.iv = phi i8 [ 0, %entry ], [ %canonical.iv.inc, %loop ] "
792         "  %i.inc = add nsw i32 %i, 1 "
793         "  %canonical.iv.inc = add i8 %canonical.iv, 1 "
794         "  %cont = icmp slt i32 %i.inc, %limit "
795         "  br i1 %cont, label %loop, label %exit "
796         "exit: "
797         "  ret i32 %i.inc "
798         "}",
799         Err, C);
800 
801     assert(M && "Could not parse module?");
802     assert(!verifyModule(*M) && "Must have been well formed!");
803 
804     runWithSE(*M, "test", [&](Function &F, LoopInfo &LI, ScalarEvolution &SE) {
805       auto &I = GetInstByName(F, "i");
806 
807       auto *LoopHeaderBB = I.getParent();
808       auto *Loop = LI.getLoopFor(LoopHeaderBB);
809       PHINode *CanonicalIV = Loop->getCanonicalInductionVariable();
810       EXPECT_EQ(CanonicalIV, &GetInstByName(F, "canonical.iv"));
811 
812       auto *AR = GetAddRec(SE, Loop);
813       EXPECT_FALSE(AR->isAffine());
814 
815       unsigned ExpectedCanonicalIVWidth = SE.getTypeSizeInBits(AR->getType());
816       unsigned CanonicalIVBitWidth =
817           cast<IntegerType>(CanonicalIV->getType())->getBitWidth();
818       EXPECT_LT(CanonicalIVBitWidth, ExpectedCanonicalIVWidth);
819 
820       SCEVExpander Exp(SE, M->getDataLayout(), "expander");
821       auto *InsertAt = I.getNextNode();
822       Value *V = Exp.expandCodeFor(AR, nullptr, InsertAt);
823       auto *ExpandedAR = SE.getSCEV(V);
824       // Check that the expansion happened literally.
825       EXPECT_EQ(AR, ExpandedAR);
826     });
827   };
828 
829   // Expand the addrec produced by GetAddRec into a loop with a canonical IV
830   // of addrec width.
831   auto TestMatchingCanonicalIV =
832       [&](std::function<const SCEVAddRecExpr *(ScalarEvolution & SE, Loop * L)>
833               GetAddRec,
834           unsigned ARBitWidth) {
835         auto ARBitWidthTypeStr = "i" + std::to_string(ARBitWidth);
836         std::unique_ptr<Module> M = parseAssemblyString(
837             "define i32 @test(i32 %limit) { "
838             "entry: "
839             "  br label %loop "
840             "loop: "
841             "  %i = phi i32 [ 1, %entry ], [ %i.inc, %loop ] "
842             "  %canonical.iv = phi " +
843                 ARBitWidthTypeStr +
844                 " [ 0, %entry ], [ %canonical.iv.inc, %loop ] "
845                 "  %i.inc = add nsw i32 %i, 1 "
846                 "  %canonical.iv.inc = add " +
847                 ARBitWidthTypeStr +
848                 " %canonical.iv, 1 "
849                 "  %cont = icmp slt i32 %i.inc, %limit "
850                 "  br i1 %cont, label %loop, label %exit "
851                 "exit: "
852                 "  ret i32 %i.inc "
853                 "}",
854             Err, C);
855 
856         assert(M && "Could not parse module?");
857         assert(!verifyModule(*M) && "Must have been well formed!");
858 
859         runWithSE(
860             *M, "test", [&](Function &F, LoopInfo &LI, ScalarEvolution &SE) {
861               auto &I = GetInstByName(F, "i");
862               auto &CanonicalIV = GetInstByName(F, "canonical.iv");
863 
864               auto *LoopHeaderBB = I.getParent();
865               auto *Loop = LI.getLoopFor(LoopHeaderBB);
866               EXPECT_EQ(&CanonicalIV, Loop->getCanonicalInductionVariable());
867               unsigned CanonicalIVBitWidth =
868                   cast<IntegerType>(CanonicalIV.getType())->getBitWidth();
869 
870               auto *AR = GetAddRec(SE, Loop);
871               EXPECT_FALSE(AR->isAffine());
872               EXPECT_EQ(ARBitWidth, SE.getTypeSizeInBits(AR->getType()));
873               EXPECT_EQ(CanonicalIVBitWidth, ARBitWidth);
874 
875               SCEVExpander Exp(SE, M->getDataLayout(), "expander");
876               auto *InsertAt = I.getNextNode();
877               Value *V = Exp.expandCodeFor(AR, nullptr, InsertAt);
878               auto *ExpandedAR = SE.getSCEV(V);
879               // Check that the expansion happened literally.
880               EXPECT_EQ(AR, ExpandedAR);
881             });
882       };
883 
884   unsigned ARBitWidth = 16;
885   Type *ARType = IntegerType::get(C, ARBitWidth);
886 
887   // Expand {5,+,1,+,1}
888   auto GetAR3 = [&](ScalarEvolution &SE, Loop *L) -> const SCEVAddRecExpr * {
889     SmallVector<const SCEV *, 3> Ops = {SE.getConstant(APInt(ARBitWidth, 5)),
890                                         SE.getOne(ARType), SE.getOne(ARType)};
891     return cast<SCEVAddRecExpr>(SE.getAddRecExpr(Ops, L, SCEV::FlagAnyWrap));
892   };
893   TestNoCanonicalIV(GetAR3);
894   TestNarrowCanonicalIV(GetAR3);
895   TestMatchingCanonicalIV(GetAR3, ARBitWidth);
896 
897   // Expand {5,+,1,+,1,+,1}
898   auto GetAR4 = [&](ScalarEvolution &SE, Loop *L) -> const SCEVAddRecExpr * {
899     SmallVector<const SCEV *, 4> Ops = {SE.getConstant(APInt(ARBitWidth, 5)),
900                                         SE.getOne(ARType), SE.getOne(ARType),
901                                         SE.getOne(ARType)};
902     return cast<SCEVAddRecExpr>(SE.getAddRecExpr(Ops, L, SCEV::FlagAnyWrap));
903   };
904   TestNoCanonicalIV(GetAR4);
905   TestNarrowCanonicalIV(GetAR4);
906   TestMatchingCanonicalIV(GetAR4, ARBitWidth);
907 
908   // Expand {5,+,1,+,1,+,1,+,1}
909   auto GetAR5 = [&](ScalarEvolution &SE, Loop *L) -> const SCEVAddRecExpr * {
910     SmallVector<const SCEV *, 5> Ops = {SE.getConstant(APInt(ARBitWidth, 5)),
911                                         SE.getOne(ARType), SE.getOne(ARType),
912                                         SE.getOne(ARType), SE.getOne(ARType)};
913     return cast<SCEVAddRecExpr>(SE.getAddRecExpr(Ops, L, SCEV::FlagAnyWrap));
914   };
915   TestNoCanonicalIV(GetAR5);
916   TestNarrowCanonicalIV(GetAR5);
917   TestMatchingCanonicalIV(GetAR5, ARBitWidth);
918 }
919 
920 } // end namespace llvm
921