xref: /llvm-project/llvm/unittests/Transforms/Utils/ScalarEvolutionExpanderTest.cpp (revision dc2867576886247cbe351e7c63618c09ab6af808)
1 //=== ScalarEvolutionExpanderTest.cpp - ScalarEvolutionExpander unit tests ===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "llvm/Transforms/Utils/ScalarEvolutionExpander.h"
10 #include "llvm/ADT/SmallVector.h"
11 #include "llvm/Analysis/AssumptionCache.h"
12 #include "llvm/Analysis/LoopInfo.h"
13 #include "llvm/Analysis/ScalarEvolutionExpressions.h"
14 #include "llvm/Analysis/TargetLibraryInfo.h"
15 #include "llvm/AsmParser/Parser.h"
16 #include "llvm/IR/Constants.h"
17 #include "llvm/IR/Dominators.h"
18 #include "llvm/IR/GlobalVariable.h"
19 #include "llvm/IR/IRBuilder.h"
20 #include "llvm/IR/InstIterator.h"
21 #include "llvm/IR/LLVMContext.h"
22 #include "llvm/IR/Module.h"
23 #include "llvm/IR/Verifier.h"
24 #include "gtest/gtest.h"
25 
26 namespace llvm {
27 
28 // We use this fixture to ensure that we clean up ScalarEvolution before
29 // deleting the PassManager.
30 class ScalarEvolutionExpanderTest : public testing::Test {
31 protected:
32   LLVMContext Context;
33   Module M;
34   TargetLibraryInfoImpl TLII;
35   TargetLibraryInfo TLI;
36 
37   std::unique_ptr<AssumptionCache> AC;
38   std::unique_ptr<DominatorTree> DT;
39   std::unique_ptr<LoopInfo> LI;
40 
41   ScalarEvolutionExpanderTest() : M("", Context), TLII(), TLI(TLII) {}
42 
43   ScalarEvolution buildSE(Function &F) {
44     AC.reset(new AssumptionCache(F));
45     DT.reset(new DominatorTree(F));
46     LI.reset(new LoopInfo(*DT));
47     return ScalarEvolution(F, TLI, *AC, *DT, *LI);
48   }
49 
50   void runWithSE(
51       Module &M, StringRef FuncName,
52       function_ref<void(Function &F, LoopInfo &LI, ScalarEvolution &SE)> Test) {
53     auto *F = M.getFunction(FuncName);
54     ASSERT_NE(F, nullptr) << "Could not find " << FuncName;
55     ScalarEvolution SE = buildSE(*F);
56     Test(*F, *LI, SE);
57   }
58 };
59 
60 static Instruction &GetInstByName(Function &F, StringRef Name) {
61   for (auto &I : instructions(F))
62     if (I.getName() == Name)
63       return I;
64   llvm_unreachable("Could not find instructions!");
65 }
66 
TEST_F(ScalarEvolutionExpanderTest, ExpandPtrTypeSCEV) {
  // It is to test the fix for PR30213. It exercises the branch in scev
  // expansion when the value in ValueOffsetPair is a ptr and the offset
  // is not divisible by the elem type size of value.
  auto *I8Ty = Type::getInt8Ty(Context);
  auto *I8PtrTy = Type::getInt8PtrTy(Context);
  auto *I32Ty = Type::getInt32Ty(Context);
  auto *I32PtrTy = Type::getInt32PtrTy(Context);
  FunctionType *FTy =
      FunctionType::get(Type::getVoidTy(Context), std::vector<Type *>(), false);
  Function *F = Function::Create(FTy, Function::ExternalLinkage, "f", M);
  BasicBlock *EntryBB = BasicBlock::Create(Context, "entry", F);
  BasicBlock *LoopBB = BasicBlock::Create(Context, "loop", F);
  BasicBlock *ExitBB = BasicBlock::Create(Context, "exit", F);
  // entry falls through into the loop; exit simply returns.
  BranchInst::Create(LoopBB, EntryBB);
  ReturnInst::Create(Context, nullptr, ExitBB);

  // loop:                            ; preds = %loop, %entry
  //   %alloca = alloca i32
  //   %gep0 = getelementptr i32, i32* %alloca, i32 1
  //   %bitcast1 = bitcast i32* %gep0 to i8*
  //   %gep1 = getelementptr i8, i8* %bitcast1, i32 1
  //   %gep2 = getelementptr i8, i8* undef, i32 1
  //   %cmp = icmp ult i8* undef, %bitcast1
  //   %select = select i1 %cmp, i8* %gep1, i8* %gep2
  //   %bitcast2 = bitcast i8* %select to i32*
  //   br i1 undef, label %loop, label %exit

  const DataLayout &DL = F->getParent()->getDataLayout();
  // The loop's terminator is created first; every instruction below passes
  // `Br` as its insertion point, so they all land before the branch.
  BranchInst *Br = BranchInst::Create(
      LoopBB, ExitBB, UndefValue::get(Type::getInt1Ty(Context)), LoopBB);
  AllocaInst *Alloca =
      new AllocaInst(I32Ty, DL.getAllocaAddrSpace(), "alloca", Br);
  ConstantInt *Ci32 = ConstantInt::get(Context, APInt(32, 1));
  GetElementPtrInst *Gep0 =
      GetElementPtrInst::Create(I32Ty, Alloca, Ci32, "gep0", Br);
  CastInst *CastA =
      CastInst::CreateBitOrPointerCast(Gep0, I8PtrTy, "bitcast1", Br);
  GetElementPtrInst *Gep1 =
      GetElementPtrInst::Create(I8Ty, CastA, Ci32, "gep1", Br);
  GetElementPtrInst *Gep2 = GetElementPtrInst::Create(
      I8Ty, UndefValue::get(I8PtrTy), Ci32, "gep2", Br);
  CmpInst *Cmp = CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_ULT,
                                 UndefValue::get(I8PtrTy), CastA, "cmp", Br);
  SelectInst *Sel = SelectInst::Create(Cmp, Gep1, Gep2, "select", Br);
  CastInst *CastB =
      CastInst::CreateBitOrPointerCast(Sel, I32PtrTy, "bitcast2", Br);

  ScalarEvolution SE = buildSE(*F);
  auto *S = SE.getSCEV(CastB);
  SCEVExpander Exp(SE, M.getDataLayout(), "expander");
  // Expand only the offset operand (operand 1) of the add expression that
  // SCEV computed for %bitcast2.
  Value *V =
      Exp.expandCodeFor(cast<SCEVAddExpr>(S)->getOperand(1), nullptr, Br);

  // Expect the expansion code contains:
  //   %0 = bitcast i32* %bitcast2 to i8*
  //   %uglygep = getelementptr i8, i8* %0, i64 -1
  //   %1 = bitcast i8* %uglygep to i32*
  EXPECT_TRUE(isa<BitCastInst>(V));
  Instruction *Gep = cast<Instruction>(V)->getPrevNode();
  EXPECT_TRUE(isa<GetElementPtrInst>(Gep));
  EXPECT_TRUE(isa<ConstantInt>(Gep->getOperand(1)));
  EXPECT_EQ(cast<ConstantInt>(Gep->getOperand(1))->getSExtValue(), -1);
  EXPECT_TRUE(isa<BitCastInst>(Gep->getPrevNode()));
}
132 
// Make sure that SCEV doesn't introduce illegal ptrtoint/inttoptr instructions
TEST_F(ScalarEvolutionExpanderTest, SCEVZeroExtendExprNonIntegral) {
  /*
   * Create the following code:
   * func(i64 addrspace(10)* %arg)
   * top:
   *  br label %L.ph
   * L.ph:
   *  br label %L
   * L:
   *  %phi = phi i64 [i64 0, %L.ph], [ %add, %L2 ]
   *  %add = add i64 %phi2, 1
   *  br i1 undef, label %post, label %L2
   * post:
   *  %gepbase = getelementptr i64 addrspace(10)* %arg, i64 1
   *  #= %gep = getelementptr i64 addrspace(10)* %gepbase, i64 %add =#
   *  ret void
   *
   * We will create the appropriate SCEV expression for %gep and expand it,
   * then check that no inttoptr/ptrtoint instructions got inserted.
   */

  // Create a module with non-integral pointers in its datalayout: the
  // "ni:10" component marks address space 10 as non-integral.
  Module NIM("nonintegral", Context);
  std::string DataLayout = M.getDataLayoutStr();
  if (!DataLayout.empty())
    DataLayout += "-";
  DataLayout += "ni:10";
  NIM.setDataLayout(DataLayout);

  Type *T_int1 = Type::getInt1Ty(Context);
  Type *T_int64 = Type::getInt64Ty(Context);
  Type *T_pint64 = T_int64->getPointerTo(10);

  FunctionType *FTy =
      FunctionType::get(Type::getVoidTy(Context), {T_pint64}, false);
  Function *F = Function::Create(FTy, Function::ExternalLinkage, "foo", NIM);

  Argument *Arg = &*F->arg_begin();

  BasicBlock *Top = BasicBlock::Create(Context, "top", F);
  BasicBlock *LPh = BasicBlock::Create(Context, "L.ph", F);
  BasicBlock *L = BasicBlock::Create(Context, "L", F);
  BasicBlock *Post = BasicBlock::Create(Context, "post", F);

  IRBuilder<> Builder(Top);
  Builder.CreateBr(LPh);

  Builder.SetInsertPoint(LPh);
  Builder.CreateBr(L);

  Builder.SetInsertPoint(L);
  PHINode *Phi = Builder.CreatePHI(T_int64, 2);
  Value *Add = Builder.CreateAdd(Phi, ConstantInt::get(T_int64, 1), "add");
  Builder.CreateCondBr(UndefValue::get(T_int1), L, Post);
  Phi->addIncoming(ConstantInt::get(T_int64, 0), LPh);
  Phi->addIncoming(Add, L);

  Builder.SetInsertPoint(Post);
  Value *GepBase =
      Builder.CreateGEP(T_int64, Arg, ConstantInt::get(T_int64, 1));
  Instruction *Ret = Builder.CreateRetVoid();

  ScalarEvolution SE = buildSE(*F);
  // Manually build the addrec {%gepbase,+,1}<nuw><%L> standing in for the
  // commented-out %gep in the sketch above.
  auto *AddRec =
      SE.getAddRecExpr(SE.getUnknown(GepBase), SE.getConstant(T_int64, 1),
                       LI->getLoopFor(L), SCEV::FlagNUW);

  SCEVExpander Exp(SE, NIM.getDataLayout(), "expander");
  // Expand literally, without canonical-mode rewriting of the addrec.
  Exp.disableCanonicalMode();
  Exp.expandCodeFor(AddRec, T_pint64, Ret);

  // Make sure none of the instructions inserted were inttoptr/ptrtoint.
  // The verifier will check this.
  EXPECT_FALSE(verifyFunction(*F, &errs()));
}
209 
210 // Check that we can correctly identify the points at which the SCEV of the
211 // AddRec can be expanded.
212 TEST_F(ScalarEvolutionExpanderTest, SCEVExpanderIsSafeToExpandAt) {
213   /*
214    * Create the following code:
215    * func(i64 addrspace(10)* %arg)
216    * top:
217    *  br label %L.ph
218    * L.ph:
219    *  br label %L
220    * L:
221    *  %phi = phi i64 [i64 0, %L.ph], [ %add, %L2 ]
222    *  %add = add i64 %phi2, 1
223    *  %cond = icmp slt i64 %add, 1000; then becomes 2000.
224    *  br i1 %cond, label %post, label %L2
225    * post:
226    *  ret void
227    *
228    */
229 
230   // Create a module with non-integral pointers in it's datalayout
231   Module NIM("nonintegral", Context);
232   std::string DataLayout = M.getDataLayoutStr();
233   if (!DataLayout.empty())
234     DataLayout += "-";
235   DataLayout += "ni:10";
236   NIM.setDataLayout(DataLayout);
237 
238   Type *T_int64 = Type::getInt64Ty(Context);
239   Type *T_pint64 = T_int64->getPointerTo(10);
240 
241   FunctionType *FTy =
242       FunctionType::get(Type::getVoidTy(Context), {T_pint64}, false);
243   Function *F = Function::Create(FTy, Function::ExternalLinkage, "foo", NIM);
244 
245   BasicBlock *Top = BasicBlock::Create(Context, "top", F);
246   BasicBlock *LPh = BasicBlock::Create(Context, "L.ph", F);
247   BasicBlock *L = BasicBlock::Create(Context, "L", F);
248   BasicBlock *Post = BasicBlock::Create(Context, "post", F);
249 
250   IRBuilder<> Builder(Top);
251   Builder.CreateBr(LPh);
252 
253   Builder.SetInsertPoint(LPh);
254   Builder.CreateBr(L);
255 
256   Builder.SetInsertPoint(L);
257   PHINode *Phi = Builder.CreatePHI(T_int64, 2);
258   auto *Add = cast<Instruction>(
259       Builder.CreateAdd(Phi, ConstantInt::get(T_int64, 1), "add"));
260   auto *Limit = ConstantInt::get(T_int64, 1000);
261   auto *Cond = cast<Instruction>(
262       Builder.CreateICmp(ICmpInst::ICMP_SLT, Add, Limit, "cond"));
263   Builder.CreateCondBr(Cond, L, Post);
264   Phi->addIncoming(ConstantInt::get(T_int64, 0), LPh);
265   Phi->addIncoming(Add, L);
266 
267   Builder.SetInsertPoint(Post);
268   Builder.CreateRetVoid();
269 
270   ScalarEvolution SE = buildSE(*F);
271   const SCEV *S = SE.getSCEV(Phi);
272   EXPECT_TRUE(isa<SCEVAddRecExpr>(S));
273   const SCEVAddRecExpr *AR = cast<SCEVAddRecExpr>(S);
274   EXPECT_TRUE(AR->isAffine());
275   EXPECT_FALSE(isSafeToExpandAt(AR, Top->getTerminator(), SE));
276   EXPECT_FALSE(isSafeToExpandAt(AR, LPh->getTerminator(), SE));
277   EXPECT_TRUE(isSafeToExpandAt(AR, L->getTerminator(), SE));
278   EXPECT_TRUE(isSafeToExpandAt(AR, Post->getTerminator(), SE));
279 }
280 
281 // Check that SCEV expander does not use the nuw instruction
282 // for expansion.
283 TEST_F(ScalarEvolutionExpanderTest, SCEVExpanderNUW) {
284   /*
285    * Create the following code:
286    * func(i64 %a)
287    * entry:
288    *   br false, label %exit, label %body
289    * body:
290    *  %s1 = add i64 %a, -1
291    *  br label %exit
292    * exit:
293    *  %s = add nuw i64 %a, -1
294    *  ret %s
295    */
296 
297   // Create a module.
298   Module M("SCEVExpanderNUW", Context);
299 
300   Type *T_int64 = Type::getInt64Ty(Context);
301 
302   FunctionType *FTy =
303       FunctionType::get(Type::getVoidTy(Context), {T_int64}, false);
304   Function *F = Function::Create(FTy, Function::ExternalLinkage, "func", M);
305   Argument *Arg = &*F->arg_begin();
306   ConstantInt *C = ConstantInt::get(Context, APInt(64, -1));
307 
308   BasicBlock *Entry = BasicBlock::Create(Context, "entry", F);
309   BasicBlock *Body = BasicBlock::Create(Context, "body", F);
310   BasicBlock *Exit = BasicBlock::Create(Context, "exit", F);
311 
312   IRBuilder<> Builder(Entry);
313   ConstantInt *Cond = ConstantInt::get(Context, APInt(1, 0));
314   Builder.CreateCondBr(Cond, Exit, Body);
315 
316   Builder.SetInsertPoint(Body);
317   auto *S1 = cast<Instruction>(Builder.CreateAdd(Arg, C, "add"));
318   Builder.CreateBr(Exit);
319 
320   Builder.SetInsertPoint(Exit);
321   auto *S2 = cast<Instruction>(Builder.CreateAdd(Arg, C, "add"));
322   S2->setHasNoUnsignedWrap(true);
323   auto *R = cast<Instruction>(Builder.CreateRetVoid());
324 
325   ScalarEvolution SE = buildSE(*F);
326   const SCEV *S = SE.getSCEV(S1);
327   EXPECT_TRUE(isa<SCEVAddExpr>(S));
328   SCEVExpander Exp(SE, M.getDataLayout(), "expander");
329   auto *I = cast<Instruction>(Exp.expandCodeFor(S, nullptr, R));
330   EXPECT_FALSE(I->hasNoUnsignedWrap());
331 }
332 
333 // Check that SCEV expander does not use the nsw instruction
334 // for expansion.
335 TEST_F(ScalarEvolutionExpanderTest, SCEVExpanderNSW) {
336   /*
337    * Create the following code:
338    * func(i64 %a)
339    * entry:
340    *   br false, label %exit, label %body
341    * body:
342    *  %s1 = add i64 %a, -1
343    *  br label %exit
344    * exit:
345    *  %s = add nsw i64 %a, -1
346    *  ret %s
347    */
348 
349   // Create a module.
350   Module M("SCEVExpanderNSW", Context);
351 
352   Type *T_int64 = Type::getInt64Ty(Context);
353 
354   FunctionType *FTy =
355       FunctionType::get(Type::getVoidTy(Context), {T_int64}, false);
356   Function *F = Function::Create(FTy, Function::ExternalLinkage, "func", M);
357   Argument *Arg = &*F->arg_begin();
358   ConstantInt *C = ConstantInt::get(Context, APInt(64, -1));
359 
360   BasicBlock *Entry = BasicBlock::Create(Context, "entry", F);
361   BasicBlock *Body = BasicBlock::Create(Context, "body", F);
362   BasicBlock *Exit = BasicBlock::Create(Context, "exit", F);
363 
364   IRBuilder<> Builder(Entry);
365   ConstantInt *Cond = ConstantInt::get(Context, APInt(1, 0));
366   Builder.CreateCondBr(Cond, Exit, Body);
367 
368   Builder.SetInsertPoint(Body);
369   auto *S1 = cast<Instruction>(Builder.CreateAdd(Arg, C, "add"));
370   Builder.CreateBr(Exit);
371 
372   Builder.SetInsertPoint(Exit);
373   auto *S2 = cast<Instruction>(Builder.CreateAdd(Arg, C, "add"));
374   S2->setHasNoSignedWrap(true);
375   auto *R = cast<Instruction>(Builder.CreateRetVoid());
376 
377   ScalarEvolution SE = buildSE(*F);
378   const SCEV *S = SE.getSCEV(S1);
379   EXPECT_TRUE(isa<SCEVAddExpr>(S));
380   SCEVExpander Exp(SE, M.getDataLayout(), "expander");
381   auto *I = cast<Instruction>(Exp.expandCodeFor(S, nullptr, R));
382   EXPECT_FALSE(I->hasNoSignedWrap());
383 }
384 
385 // Check that SCEV does not save the SCEV -> V
386 // mapping of SCEV differ from V in NUW flag.
387 TEST_F(ScalarEvolutionExpanderTest, SCEVCacheNUW) {
388   /*
389    * Create the following code:
390    * func(i64 %a)
391    * entry:
392    *  %s1 = add i64 %a, -1
393    *  %s2 = add nuw i64 %a, -1
394    *  br label %exit
395    * exit:
396    *  ret %s
397    */
398 
399   // Create a module.
400   Module M("SCEVCacheNUW", Context);
401 
402   Type *T_int64 = Type::getInt64Ty(Context);
403 
404   FunctionType *FTy =
405       FunctionType::get(Type::getVoidTy(Context), {T_int64}, false);
406   Function *F = Function::Create(FTy, Function::ExternalLinkage, "func", M);
407   Argument *Arg = &*F->arg_begin();
408   ConstantInt *C = ConstantInt::get(Context, APInt(64, -1));
409 
410   BasicBlock *Entry = BasicBlock::Create(Context, "entry", F);
411   BasicBlock *Exit = BasicBlock::Create(Context, "exit", F);
412 
413   IRBuilder<> Builder(Entry);
414   auto *S1 = cast<Instruction>(Builder.CreateAdd(Arg, C, "add"));
415   auto *S2 = cast<Instruction>(Builder.CreateAdd(Arg, C, "add"));
416   S2->setHasNoUnsignedWrap(true);
417   Builder.CreateBr(Exit);
418 
419   Builder.SetInsertPoint(Exit);
420   auto *R = cast<Instruction>(Builder.CreateRetVoid());
421 
422   ScalarEvolution SE = buildSE(*F);
423   // Get S2 first to move it to cache.
424   const SCEV *SC2 = SE.getSCEV(S2);
425   EXPECT_TRUE(isa<SCEVAddExpr>(SC2));
426   // Now get S1.
427   const SCEV *SC1 = SE.getSCEV(S1);
428   EXPECT_TRUE(isa<SCEVAddExpr>(SC1));
429   // Expand for S1, it should use S1 not S2 in spite S2
430   // first in the cache.
431   SCEVExpander Exp(SE, M.getDataLayout(), "expander");
432   auto *I = cast<Instruction>(Exp.expandCodeFor(SC1, nullptr, R));
433   EXPECT_FALSE(I->hasNoUnsignedWrap());
434 }
435 
436 // Check that SCEV does not save the SCEV -> V
437 // mapping of SCEV differ from V in NSW flag.
438 TEST_F(ScalarEvolutionExpanderTest, SCEVCacheNSW) {
439   /*
440    * Create the following code:
441    * func(i64 %a)
442    * entry:
443    *  %s1 = add i64 %a, -1
444    *  %s2 = add nsw i64 %a, -1
445    *  br label %exit
446    * exit:
447    *  ret %s
448    */
449 
450   // Create a module.
451   Module M("SCEVCacheNUW", Context);
452 
453   Type *T_int64 = Type::getInt64Ty(Context);
454 
455   FunctionType *FTy =
456       FunctionType::get(Type::getVoidTy(Context), {T_int64}, false);
457   Function *F = Function::Create(FTy, Function::ExternalLinkage, "func", M);
458   Argument *Arg = &*F->arg_begin();
459   ConstantInt *C = ConstantInt::get(Context, APInt(64, -1));
460 
461   BasicBlock *Entry = BasicBlock::Create(Context, "entry", F);
462   BasicBlock *Exit = BasicBlock::Create(Context, "exit", F);
463 
464   IRBuilder<> Builder(Entry);
465   auto *S1 = cast<Instruction>(Builder.CreateAdd(Arg, C, "add"));
466   auto *S2 = cast<Instruction>(Builder.CreateAdd(Arg, C, "add"));
467   S2->setHasNoSignedWrap(true);
468   Builder.CreateBr(Exit);
469 
470   Builder.SetInsertPoint(Exit);
471   auto *R = cast<Instruction>(Builder.CreateRetVoid());
472 
473   ScalarEvolution SE = buildSE(*F);
474   // Get S2 first to move it to cache.
475   const SCEV *SC2 = SE.getSCEV(S2);
476   EXPECT_TRUE(isa<SCEVAddExpr>(SC2));
477   // Now get S1.
478   const SCEV *SC1 = SE.getSCEV(S1);
479   EXPECT_TRUE(isa<SCEVAddExpr>(SC1));
480   // Expand for S1, it should use S1 not S2 in spite S2
481   // first in the cache.
482   SCEVExpander Exp(SE, M.getDataLayout(), "expander");
483   auto *I = cast<Instruction>(Exp.expandCodeFor(SC1, nullptr, R));
484   EXPECT_FALSE(I->hasNoSignedWrap());
485 }
486 
487 TEST_F(ScalarEvolutionExpanderTest, SCEVExpandInsertCanonicalIV) {
488   LLVMContext C;
489   SMDiagnostic Err;
490 
491   // Expand the addrec produced by GetAddRec into a loop without a canonical IV.
492   // SCEVExpander will insert one.
493   auto TestNoCanonicalIV =
494       [&](std::function<const SCEV *(ScalarEvolution & SE, Loop * L)>
495               GetAddRec) {
496         std::unique_ptr<Module> M = parseAssemblyString(
497             "define i32 @test(i32 %limit) { "
498             "entry: "
499             "  br label %loop "
500             "loop: "
501             "  %i = phi i32 [ 1, %entry ], [ %i.inc, %loop ] "
502             "  %i.inc = add nsw i32 %i, 1 "
503             "  %cont = icmp slt i32 %i.inc, %limit "
504             "  br i1 %cont, label %loop, label %exit "
505             "exit: "
506             "  ret i32 %i.inc "
507             "}",
508             Err, C);
509 
510         assert(M && "Could not parse module?");
511         assert(!verifyModule(*M) && "Must have been well formed!");
512 
513         runWithSE(
514             *M, "test", [&](Function &F, LoopInfo &LI, ScalarEvolution &SE) {
515               auto &I = GetInstByName(F, "i");
516               auto *Loop = LI.getLoopFor(I.getParent());
517               EXPECT_FALSE(Loop->getCanonicalInductionVariable());
518 
519               auto *AR = GetAddRec(SE, Loop);
520               unsigned ExpectedCanonicalIVWidth =
521                   SE.getTypeSizeInBits(AR->getType());
522 
523               SCEVExpander Exp(SE, M->getDataLayout(), "expander");
524               auto *InsertAt = I.getNextNode();
525               Exp.expandCodeFor(AR, nullptr, InsertAt);
526               PHINode *CanonicalIV = Loop->getCanonicalInductionVariable();
527               unsigned CanonicalIVBitWidth =
528                   cast<IntegerType>(CanonicalIV->getType())->getBitWidth();
529               EXPECT_EQ(CanonicalIVBitWidth, ExpectedCanonicalIVWidth);
530             });
531       };
532 
533   // Expand the addrec produced by GetAddRec into a loop with a canonical IV
534   // which is narrower than addrec type.
535   // SCEVExpander will insert a canonical IV of a wider type to expand the
536   // addrec.
537   auto TestNarrowCanonicalIV = [&](std::function<const SCEV *(
538                                        ScalarEvolution & SE, Loop * L)>
539                                        GetAddRec) {
540     std::unique_ptr<Module> M = parseAssemblyString(
541         "define i32 @test(i32 %limit) { "
542         "entry: "
543         "  br label %loop "
544         "loop: "
545         "  %i = phi i32 [ 1, %entry ], [ %i.inc, %loop ] "
546         "  %canonical.iv = phi i8 [ 0, %entry ], [ %canonical.iv.inc, %loop ] "
547         "  %i.inc = add nsw i32 %i, 1 "
548         "  %canonical.iv.inc = add i8 %canonical.iv, 1 "
549         "  %cont = icmp slt i32 %i.inc, %limit "
550         "  br i1 %cont, label %loop, label %exit "
551         "exit: "
552         "  ret i32 %i.inc "
553         "}",
554         Err, C);
555 
556     assert(M && "Could not parse module?");
557     assert(!verifyModule(*M) && "Must have been well formed!");
558 
559     runWithSE(*M, "test", [&](Function &F, LoopInfo &LI, ScalarEvolution &SE) {
560       auto &I = GetInstByName(F, "i");
561 
562       auto *LoopHeaderBB = I.getParent();
563       auto *Loop = LI.getLoopFor(LoopHeaderBB);
564       PHINode *CanonicalIV = Loop->getCanonicalInductionVariable();
565       EXPECT_EQ(CanonicalIV, &GetInstByName(F, "canonical.iv"));
566 
567       auto *AR = GetAddRec(SE, Loop);
568 
569       unsigned ExpectedCanonicalIVWidth = SE.getTypeSizeInBits(AR->getType());
570       unsigned CanonicalIVBitWidth =
571           cast<IntegerType>(CanonicalIV->getType())->getBitWidth();
572       EXPECT_LT(CanonicalIVBitWidth, ExpectedCanonicalIVWidth);
573 
574       SCEVExpander Exp(SE, M->getDataLayout(), "expander");
575       auto *InsertAt = I.getNextNode();
576       Exp.expandCodeFor(AR, nullptr, InsertAt);
577 
578       // Loop over all of the PHI nodes, looking for the new canonical indvar.
579       PHINode *NewCanonicalIV = nullptr;
580       for (BasicBlock::iterator i = LoopHeaderBB->begin(); isa<PHINode>(i);
581            ++i) {
582         PHINode *PN = cast<PHINode>(i);
583         if (PN == &I || PN == CanonicalIV)
584           continue;
585         // We expect that the only PHI added is the new canonical IV
586         EXPECT_FALSE(NewCanonicalIV);
587         NewCanonicalIV = PN;
588       }
589 
590       // Check that NewCanonicalIV is a canonical IV, i.e {0,+,1}
591       BasicBlock *Incoming = nullptr, *Backedge = nullptr;
592       EXPECT_TRUE(Loop->getIncomingAndBackEdge(Incoming, Backedge));
593       auto *Start = NewCanonicalIV->getIncomingValueForBlock(Incoming);
594       EXPECT_TRUE(isa<ConstantInt>(Start));
595       EXPECT_TRUE(dyn_cast<ConstantInt>(Start)->isZero());
596       auto *Next = NewCanonicalIV->getIncomingValueForBlock(Backedge);
597       EXPECT_TRUE(isa<BinaryOperator>(Next));
598       auto *NextBinOp = dyn_cast<BinaryOperator>(Next);
599       EXPECT_EQ(NextBinOp->getOpcode(), Instruction::Add);
600       EXPECT_EQ(NextBinOp->getOperand(0), NewCanonicalIV);
601       auto *Step = NextBinOp->getOperand(1);
602       EXPECT_TRUE(isa<ConstantInt>(Step));
603       EXPECT_TRUE(dyn_cast<ConstantInt>(Step)->isOne());
604 
605       unsigned NewCanonicalIVBitWidth =
606           cast<IntegerType>(NewCanonicalIV->getType())->getBitWidth();
607       EXPECT_EQ(NewCanonicalIVBitWidth, ExpectedCanonicalIVWidth);
608     });
609   };
610 
611   // Expand the addrec produced by GetAddRec into a loop with a canonical IV
612   // of addrec width.
613   // To expand the addrec SCEVExpander should use the existing canonical IV.
614   auto TestMatchingCanonicalIV =
615       [&](std::function<const SCEV *(ScalarEvolution & SE, Loop * L)> GetAddRec,
616           unsigned ARBitWidth) {
617         auto ARBitWidthTypeStr = "i" + std::to_string(ARBitWidth);
618         std::unique_ptr<Module> M = parseAssemblyString(
619             "define i32 @test(i32 %limit) { "
620             "entry: "
621             "  br label %loop "
622             "loop: "
623             "  %i = phi i32 [ 1, %entry ], [ %i.inc, %loop ] "
624             "  %canonical.iv = phi " +
625                 ARBitWidthTypeStr +
626                 " [ 0, %entry ], [ %canonical.iv.inc, %loop ] "
627                 "  %i.inc = add nsw i32 %i, 1 "
628                 "  %canonical.iv.inc = add " +
629                 ARBitWidthTypeStr +
630                 " %canonical.iv, 1 "
631                 "  %cont = icmp slt i32 %i.inc, %limit "
632                 "  br i1 %cont, label %loop, label %exit "
633                 "exit: "
634                 "  ret i32 %i.inc "
635                 "}",
636             Err, C);
637 
638         assert(M && "Could not parse module?");
639         assert(!verifyModule(*M) && "Must have been well formed!");
640 
641         runWithSE(
642             *M, "test", [&](Function &F, LoopInfo &LI, ScalarEvolution &SE) {
643               auto &I = GetInstByName(F, "i");
644               auto &CanonicalIV = GetInstByName(F, "canonical.iv");
645 
646               auto *LoopHeaderBB = I.getParent();
647               auto *Loop = LI.getLoopFor(LoopHeaderBB);
648               EXPECT_EQ(&CanonicalIV, Loop->getCanonicalInductionVariable());
649               unsigned CanonicalIVBitWidth =
650                   cast<IntegerType>(CanonicalIV.getType())->getBitWidth();
651 
652               auto *AR = GetAddRec(SE, Loop);
653               EXPECT_EQ(ARBitWidth, SE.getTypeSizeInBits(AR->getType()));
654               EXPECT_EQ(CanonicalIVBitWidth, ARBitWidth);
655 
656               SCEVExpander Exp(SE, M->getDataLayout(), "expander");
657               auto *InsertAt = I.getNextNode();
658               Exp.expandCodeFor(AR, nullptr, InsertAt);
659 
660               // Loop over all of the PHI nodes, looking if a new canonical
661               // indvar was introduced.
662               PHINode *NewCanonicalIV = nullptr;
663               for (BasicBlock::iterator i = LoopHeaderBB->begin();
664                    isa<PHINode>(i); ++i) {
665                 PHINode *PN = cast<PHINode>(i);
666                 if (PN == &I || PN == &CanonicalIV)
667                   continue;
668                 NewCanonicalIV = PN;
669               }
670               EXPECT_FALSE(NewCanonicalIV);
671             });
672       };
673 
674   unsigned ARBitWidth = 16;
675   Type *ARType = IntegerType::get(C, ARBitWidth);
676 
677   // Expand {5,+,1}
678   auto GetAR2 = [&](ScalarEvolution &SE, Loop *L) -> const SCEV * {
679     return SE.getAddRecExpr(SE.getConstant(APInt(ARBitWidth, 5)),
680                             SE.getOne(ARType), L, SCEV::FlagAnyWrap);
681   };
682   TestNoCanonicalIV(GetAR2);
683   TestNarrowCanonicalIV(GetAR2);
684   TestMatchingCanonicalIV(GetAR2, ARBitWidth);
685 }
686 
687 TEST_F(ScalarEvolutionExpanderTest, SCEVExpanderShlNSW) {
688 
689   auto checkOneCase = [this](std::string &&str) {
690     LLVMContext C;
691     SMDiagnostic Err;
692     std::unique_ptr<Module> M = parseAssemblyString(str, Err, C);
693 
694     assert(M && "Could not parse module?");
695     assert(!verifyModule(*M) && "Must have been well formed!");
696 
697     Function *F = M->getFunction("f");
698     ASSERT_NE(F, nullptr) << "Could not find function 'f'";
699 
700     BasicBlock &Entry = F->getEntryBlock();
701     LoadInst *Load = cast<LoadInst>(&Entry.front());
702     BinaryOperator *And = cast<BinaryOperator>(*Load->user_begin());
703 
704     ScalarEvolution SE = buildSE(*F);
705     const SCEV *AndSCEV = SE.getSCEV(And);
706     EXPECT_TRUE(isa<SCEVMulExpr>(AndSCEV));
707     EXPECT_TRUE(cast<SCEVMulExpr>(AndSCEV)->hasNoSignedWrap());
708 
709     SCEVExpander Exp(SE, M->getDataLayout(), "expander");
710     auto *I = cast<Instruction>(Exp.expandCodeFor(AndSCEV, nullptr, And));
711     EXPECT_EQ(I->getOpcode(), Instruction::Shl);
712     EXPECT_FALSE(I->hasNoSignedWrap());
713   };
714 
715   checkOneCase("define void @f(i16* %arrayidx) { "
716                "  %1 = load i16, i16* %arrayidx "
717                "  %2 = and i16 %1, -32768 "
718                "  ret void "
719                "} ");
720 
721   checkOneCase("define void @f(i8* %arrayidx) { "
722                "  %1 = load i8, i8* %arrayidx "
723                "  %2 = and i8 %1, -128 "
724                "  ret void "
725                "} ");
726 }
727 
728 // Test expansion of nested addrecs in CanonicalMode.
// Expanding nested addrecs in canonical mode requires a canonical IV of a
730 // type wider than the type of the addrec itself. Currently, SCEVExpander
731 // just falls back to literal mode for nested addrecs.
732 TEST_F(ScalarEvolutionExpanderTest, SCEVExpandNonAffineAddRec) {
733   LLVMContext C;
734   SMDiagnostic Err;
735 
736   // Expand the addrec produced by GetAddRec into a loop without a canonical IV.
737   auto TestNoCanonicalIV =
738       [&](std::function<const SCEVAddRecExpr *(ScalarEvolution & SE, Loop * L)>
739               GetAddRec) {
740         std::unique_ptr<Module> M = parseAssemblyString(
741             "define i32 @test(i32 %limit) { "
742             "entry: "
743             "  br label %loop "
744             "loop: "
745             "  %i = phi i32 [ 1, %entry ], [ %i.inc, %loop ] "
746             "  %i.inc = add nsw i32 %i, 1 "
747             "  %cont = icmp slt i32 %i.inc, %limit "
748             "  br i1 %cont, label %loop, label %exit "
749             "exit: "
750             "  ret i32 %i.inc "
751             "}",
752             Err, C);
753 
754         assert(M && "Could not parse module?");
755         assert(!verifyModule(*M) && "Must have been well formed!");
756 
757         runWithSE(*M, "test",
758                   [&](Function &F, LoopInfo &LI, ScalarEvolution &SE) {
759                     auto &I = GetInstByName(F, "i");
760                     auto *Loop = LI.getLoopFor(I.getParent());
761                     EXPECT_FALSE(Loop->getCanonicalInductionVariable());
762 
763                     auto *AR = GetAddRec(SE, Loop);
764                     EXPECT_FALSE(AR->isAffine());
765 
766                     SCEVExpander Exp(SE, M->getDataLayout(), "expander");
767                     auto *InsertAt = I.getNextNode();
768                     Value *V = Exp.expandCodeFor(AR, nullptr, InsertAt);
769                     auto *ExpandedAR = SE.getSCEV(V);
770                     // Check that the expansion happened literally.
771                     EXPECT_EQ(AR, ExpandedAR);
772                   });
773       };
774 
775   // Expand the addrec produced by GetAddRec into a loop with a canonical IV
776   // which is narrower than addrec type.
777   auto TestNarrowCanonicalIV = [&](std::function<const SCEVAddRecExpr *(
778                                        ScalarEvolution & SE, Loop * L)>
779                                        GetAddRec) {
780     std::unique_ptr<Module> M = parseAssemblyString(
781         "define i32 @test(i32 %limit) { "
782         "entry: "
783         "  br label %loop "
784         "loop: "
785         "  %i = phi i32 [ 1, %entry ], [ %i.inc, %loop ] "
786         "  %canonical.iv = phi i8 [ 0, %entry ], [ %canonical.iv.inc, %loop ] "
787         "  %i.inc = add nsw i32 %i, 1 "
788         "  %canonical.iv.inc = add i8 %canonical.iv, 1 "
789         "  %cont = icmp slt i32 %i.inc, %limit "
790         "  br i1 %cont, label %loop, label %exit "
791         "exit: "
792         "  ret i32 %i.inc "
793         "}",
794         Err, C);
795 
796     assert(M && "Could not parse module?");
797     assert(!verifyModule(*M) && "Must have been well formed!");
798 
799     runWithSE(*M, "test", [&](Function &F, LoopInfo &LI, ScalarEvolution &SE) {
800       auto &I = GetInstByName(F, "i");
801 
802       auto *LoopHeaderBB = I.getParent();
803       auto *Loop = LI.getLoopFor(LoopHeaderBB);
804       PHINode *CanonicalIV = Loop->getCanonicalInductionVariable();
805       EXPECT_EQ(CanonicalIV, &GetInstByName(F, "canonical.iv"));
806 
807       auto *AR = GetAddRec(SE, Loop);
808       EXPECT_FALSE(AR->isAffine());
809 
810       unsigned ExpectedCanonicalIVWidth = SE.getTypeSizeInBits(AR->getType());
811       unsigned CanonicalIVBitWidth =
812           cast<IntegerType>(CanonicalIV->getType())->getBitWidth();
813       EXPECT_LT(CanonicalIVBitWidth, ExpectedCanonicalIVWidth);
814 
815       SCEVExpander Exp(SE, M->getDataLayout(), "expander");
816       auto *InsertAt = I.getNextNode();
817       Value *V = Exp.expandCodeFor(AR, nullptr, InsertAt);
818       auto *ExpandedAR = SE.getSCEV(V);
819       // Check that the expansion happened literally.
820       EXPECT_EQ(AR, ExpandedAR);
821     });
822   };
823 
824   // Expand the addrec produced by GetAddRec into a loop with a canonical IV
825   // of addrec width.
826   auto TestMatchingCanonicalIV =
827       [&](std::function<const SCEVAddRecExpr *(ScalarEvolution & SE, Loop * L)>
828               GetAddRec,
829           unsigned ARBitWidth) {
830         auto ARBitWidthTypeStr = "i" + std::to_string(ARBitWidth);
831         std::unique_ptr<Module> M = parseAssemblyString(
832             "define i32 @test(i32 %limit) { "
833             "entry: "
834             "  br label %loop "
835             "loop: "
836             "  %i = phi i32 [ 1, %entry ], [ %i.inc, %loop ] "
837             "  %canonical.iv = phi " +
838                 ARBitWidthTypeStr +
839                 " [ 0, %entry ], [ %canonical.iv.inc, %loop ] "
840                 "  %i.inc = add nsw i32 %i, 1 "
841                 "  %canonical.iv.inc = add " +
842                 ARBitWidthTypeStr +
843                 " %canonical.iv, 1 "
844                 "  %cont = icmp slt i32 %i.inc, %limit "
845                 "  br i1 %cont, label %loop, label %exit "
846                 "exit: "
847                 "  ret i32 %i.inc "
848                 "}",
849             Err, C);
850 
851         assert(M && "Could not parse module?");
852         assert(!verifyModule(*M) && "Must have been well formed!");
853 
854         runWithSE(
855             *M, "test", [&](Function &F, LoopInfo &LI, ScalarEvolution &SE) {
856               auto &I = GetInstByName(F, "i");
857               auto &CanonicalIV = GetInstByName(F, "canonical.iv");
858 
859               auto *LoopHeaderBB = I.getParent();
860               auto *Loop = LI.getLoopFor(LoopHeaderBB);
861               EXPECT_EQ(&CanonicalIV, Loop->getCanonicalInductionVariable());
862               unsigned CanonicalIVBitWidth =
863                   cast<IntegerType>(CanonicalIV.getType())->getBitWidth();
864 
865               auto *AR = GetAddRec(SE, Loop);
866               EXPECT_FALSE(AR->isAffine());
867               EXPECT_EQ(ARBitWidth, SE.getTypeSizeInBits(AR->getType()));
868               EXPECT_EQ(CanonicalIVBitWidth, ARBitWidth);
869 
870               SCEVExpander Exp(SE, M->getDataLayout(), "expander");
871               auto *InsertAt = I.getNextNode();
872               Value *V = Exp.expandCodeFor(AR, nullptr, InsertAt);
873               auto *ExpandedAR = SE.getSCEV(V);
874               // Check that the expansion happened literally.
875               EXPECT_EQ(AR, ExpandedAR);
876             });
877       };
878 
879   unsigned ARBitWidth = 16;
880   Type *ARType = IntegerType::get(C, ARBitWidth);
881 
882   // Expand {5,+,1,+,1}
883   auto GetAR3 = [&](ScalarEvolution &SE, Loop *L) -> const SCEVAddRecExpr * {
884     SmallVector<const SCEV *, 3> Ops = {SE.getConstant(APInt(ARBitWidth, 5)),
885                                         SE.getOne(ARType), SE.getOne(ARType)};
886     return cast<SCEVAddRecExpr>(SE.getAddRecExpr(Ops, L, SCEV::FlagAnyWrap));
887   };
888   TestNoCanonicalIV(GetAR3);
889   TestNarrowCanonicalIV(GetAR3);
890   TestMatchingCanonicalIV(GetAR3, ARBitWidth);
891 
892   // Expand {5,+,1,+,1,+,1}
893   auto GetAR4 = [&](ScalarEvolution &SE, Loop *L) -> const SCEVAddRecExpr * {
894     SmallVector<const SCEV *, 4> Ops = {SE.getConstant(APInt(ARBitWidth, 5)),
895                                         SE.getOne(ARType), SE.getOne(ARType),
896                                         SE.getOne(ARType)};
897     return cast<SCEVAddRecExpr>(SE.getAddRecExpr(Ops, L, SCEV::FlagAnyWrap));
898   };
899   TestNoCanonicalIV(GetAR4);
900   TestNarrowCanonicalIV(GetAR4);
901   TestMatchingCanonicalIV(GetAR4, ARBitWidth);
902 
903   // Expand {5,+,1,+,1,+,1,+,1}
904   auto GetAR5 = [&](ScalarEvolution &SE, Loop *L) -> const SCEVAddRecExpr * {
905     SmallVector<const SCEV *, 5> Ops = {SE.getConstant(APInt(ARBitWidth, 5)),
906                                         SE.getOne(ARType), SE.getOne(ARType),
907                                         SE.getOne(ARType), SE.getOne(ARType)};
908     return cast<SCEVAddRecExpr>(SE.getAddRecExpr(Ops, L, SCEV::FlagAnyWrap));
909   };
910   TestNoCanonicalIV(GetAR5);
911   TestNarrowCanonicalIV(GetAR5);
912   TestMatchingCanonicalIV(GetAR5, ARBitWidth);
913 }
914 
915 } // end namespace llvm
916