//=== ScalarEvolutionExpanderTest.cpp - ScalarEvolutionExpander unit tests ===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "llvm/Transforms/Utils/ScalarEvolutionExpander.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Analysis/AssumptionCache.h"
#include "llvm/Analysis/LoopInfo.h"
#include "llvm/Analysis/ScalarEvolutionExpressions.h"
#include "llvm/Analysis/TargetLibraryInfo.h"
#include "llvm/AsmParser/Parser.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/Dominators.h"
#include "llvm/IR/GlobalVariable.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/InstIterator.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/PatternMatch.h"
#include "llvm/IR/Verifier.h"
#include "llvm/Support/SourceMgr.h"
#include "gtest/gtest.h"

namespace llvm {

using namespace PatternMatch;

// We use this fixture to ensure that ScalarEvolution is cleaned up before
// the analyses it depends on are deleted.
class ScalarEvolutionExpanderTest : public testing::Test {
protected:
  LLVMContext Context;
  Module M;
  TargetLibraryInfoImpl TLII;
  TargetLibraryInfo TLI;

  std::unique_ptr<AssumptionCache> AC;
  std::unique_ptr<DominatorTree> DT;
  std::unique_ptr<LoopInfo> LI;

  ScalarEvolutionExpanderTest() : M("", Context), TLII(), TLI(TLII) {}

  ScalarEvolution buildSE(Function &F) {
    AC.reset(new AssumptionCache(F));
    DT.reset(new DominatorTree(F));
    LI.reset(new LoopInfo(*DT));
    return ScalarEvolution(F, TLI, *AC, *DT, *LI);
  }
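  // Note: the returned ScalarEvolution holds references to AC, DT and LI, so
  // those fixture members must outlive every ScalarEvolution built here.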

  void runWithSE(
      Module &M, StringRef FuncName,
      function_ref<void(Function &F, LoopInfo &LI, ScalarEvolution &SE)> Test) {
    auto *F = M.getFunction(FuncName);
    ASSERT_NE(F, nullptr) << "Could not find " << FuncName;
    ScalarEvolution SE = buildSE(*F);
    Test(*F, *LI, SE);
  }
};

static Instruction &GetInstByName(Function &F, StringRef Name) {
  for (auto &I : instructions(F))
    if (I.getName() == Name)
      return I;
  llvm_unreachable("Could not find instruction!");
}

TEST_F(ScalarEvolutionExpanderTest, ExpandPtrTypeSCEV) {
  // This tests the fix for PR30213. It exercises the branch in SCEV
  // expansion where the value in a ValueOffsetPair is a pointer and the
  // offset is not divisible by the element type size of the value.
  auto *I8Ty = Type::getInt8Ty(Context);
  auto *I8PtrTy = PointerType::get(Context, 0);
  auto *I32Ty = Type::getInt32Ty(Context);
  auto *I32PtrTy = PointerType::get(Context, 0);
  FunctionType *FTy =
      FunctionType::get(Type::getVoidTy(Context), std::vector<Type *>(), false);
  Function *F = Function::Create(FTy, Function::ExternalLinkage, "f", M);
  BasicBlock *EntryBB = BasicBlock::Create(Context, "entry", F);
  BasicBlock *LoopBB = BasicBlock::Create(Context, "loop", F);
  BasicBlock *ExitBB = BasicBlock::Create(Context, "exit", F);
  BranchInst::Create(LoopBB, EntryBB);
  ReturnInst::Create(Context, nullptr, ExitBB);

  // loop:                            ; preds = %loop, %entry
  //   %alloca = alloca i32
  //   %gep0 = getelementptr i32, i32* %alloca, i32 1
  //   %bitcast1 = bitcast i32* %gep0 to i8*
  //   %gep1 = getelementptr i8, i8* %bitcast1, i32 1
  //   %gep2 = getelementptr i8, i8* undef, i32 1
  //   %cmp = icmp ult i8* undef, %bitcast1
  //   %select = select i1 %cmp, i8* %gep1, i8* %gep2
  //   %bitcast2 = bitcast i8* %select to i32*
  //   br i1 undef, label %loop, label %exit

  const DataLayout &DL = F->getDataLayout();
  BranchInst *Br = BranchInst::Create(
      LoopBB, ExitBB, UndefValue::get(Type::getInt1Ty(Context)), LoopBB);
  AllocaInst *Alloca =
      new AllocaInst(I32Ty, DL.getAllocaAddrSpace(), "alloca", Br);
  ConstantInt *Ci32 = ConstantInt::get(Context, APInt(32, 1));
  GetElementPtrInst *Gep0 =
      GetElementPtrInst::Create(I32Ty, Alloca, Ci32, "gep0", Br);
  CastInst *CastA =
      CastInst::CreateBitOrPointerCast(Gep0, I8PtrTy, "bitcast1", Br);
  GetElementPtrInst *Gep1 =
      GetElementPtrInst::Create(I8Ty, CastA, Ci32, "gep1", Br);
  GetElementPtrInst *Gep2 = GetElementPtrInst::Create(
      I8Ty, UndefValue::get(I8PtrTy), Ci32, "gep2", Br);
  CmpInst *Cmp = CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_ULT,
                                 UndefValue::get(I8PtrTy), CastA, "cmp", Br);
  SelectInst *Sel = SelectInst::Create(Cmp, Gep1, Gep2, "select", Br);
  CastInst *CastB =
      CastInst::CreateBitOrPointerCast(Sel, I32PtrTy, "bitcast2", Br);

  ScalarEvolution SE = buildSE(*F);
  const SCEV *S = SE.getSCEV(CastB);
  EXPECT_TRUE(isa<SCEVUnknown>(S));
}

// Make sure that SCEV doesn't introduce illegal ptrtoint/inttoptr instructions
TEST_F(ScalarEvolutionExpanderTest, SCEVZeroExtendExprNonIntegral) {
  /*
   * Create the following code:
   * func(i64 addrspace(10)* %arg)
   * top:
   *  br label %L.ph
   * L.ph:
   *  %gepbase = getelementptr i64 addrspace(10)* %arg, i64 1
   *  br label %L
   * L:
   *  %phi = phi i64 [ 0, %L.ph ], [ %add, %L ]
   *  %add = add i64 %phi, 1
   *  br i1 undef, label %L, label %post
   * post:
   *  #= %gep = getelementptr i64 addrspace(10)* %gepbase, i64 %add =#
   *  ret void
   *
   * We will create the appropriate SCEV expression for %gep and expand it,
   * then check that no inttoptr/ptrtoint instructions got inserted.
   */

  // Create a module with non-integral pointers in its datalayout.
  Module NIM("nonintegral", Context);
  std::string DataLayout = M.getDataLayoutStr();
  if (!DataLayout.empty())
    DataLayout += "-";
  DataLayout += "ni:10";
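  // "ni:10" marks address space 10 as non-integral: pointers there have an
  // unspecified bitwise representation, so ptrtoint/inttoptr round-trips
  // cannot be used on them.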
  NIM.setDataLayout(DataLayout);

  Type *T_int1 = Type::getInt1Ty(Context);
  Type *T_int64 = Type::getInt64Ty(Context);
  Type *T_pint64 = PointerType::get(Context, 10);

  FunctionType *FTy =
      FunctionType::get(Type::getVoidTy(Context), {T_pint64}, false);
  Function *F = Function::Create(FTy, Function::ExternalLinkage, "foo", NIM);

  Argument *Arg = &*F->arg_begin();

  BasicBlock *Top = BasicBlock::Create(Context, "top", F);
  BasicBlock *LPh = BasicBlock::Create(Context, "L.ph", F);
  BasicBlock *L = BasicBlock::Create(Context, "L", F);
  BasicBlock *Post = BasicBlock::Create(Context, "post", F);

  IRBuilder<> Builder(Top);
  Builder.CreateBr(LPh);

  Builder.SetInsertPoint(LPh);
  Value *GepBase =
      Builder.CreateGEP(T_int64, Arg, ConstantInt::get(T_int64, 1));
  Builder.CreateBr(L);

  Builder.SetInsertPoint(L);
  PHINode *Phi = Builder.CreatePHI(T_int64, 2);
  Value *Add = Builder.CreateAdd(Phi, ConstantInt::get(T_int64, 1), "add");
  Builder.CreateCondBr(UndefValue::get(T_int1), L, Post);
  Phi->addIncoming(ConstantInt::get(T_int64, 0), LPh);
  Phi->addIncoming(Add, L);

  Builder.SetInsertPoint(Post);
  Instruction *Ret = Builder.CreateRetVoid();

  ScalarEvolution SE = buildSE(*F);
  const SCEV *AddRec =
      SE.getAddRecExpr(SE.getUnknown(GepBase), SE.getConstant(T_int64, 1),
                       LI->getLoopFor(L), SCEV::FlagNUW);
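  // This is the addrec {%gepbase,+,1}<nuw><%L> with a non-integral pointer
  // base; expanding it must produce GEP arithmetic rather than
  // ptrtoint/inttoptr round-trips.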

  SCEVExpander Exp(SE, NIM.getDataLayout(), "expander");
  Exp.disableCanonicalMode();
  Exp.expandCodeFor(AddRec, T_pint64, Ret);

  // Make sure none of the instructions inserted were inttoptr/ptrtoint.
  // The verifier will check this.
  EXPECT_FALSE(verifyFunction(*F, &errs()));
}

// Check that we can correctly identify the points at which the SCEV of the
// AddRec can be expanded.
TEST_F(ScalarEvolutionExpanderTest, SCEVExpanderIsSafeToExpandAt) {
  /*
   * Create the following code:
   * func(i64 addrspace(10)* %arg)
   * top:
   *  br label %L.ph
   * L.ph:
   *  br label %L
   * L:
   *  %phi = phi i64 [ 0, %L.ph ], [ %add, %L ]
   *  %add = add i64 %phi, 1
   *  %cond = icmp slt i64 %add, 1000
   *  br i1 %cond, label %L, label %post
   * post:
   *  ret void
   *
   */

  // Create a module with non-integral pointers in its datalayout.
  Module NIM("nonintegral", Context);
  std::string DataLayout = M.getDataLayoutStr();
  if (!DataLayout.empty())
    DataLayout += "-";
  DataLayout += "ni:10";
  NIM.setDataLayout(DataLayout);

  Type *T_int64 = Type::getInt64Ty(Context);
  Type *T_pint64 = PointerType::get(Context, 10);

  FunctionType *FTy =
      FunctionType::get(Type::getVoidTy(Context), {T_pint64}, false);
  Function *F = Function::Create(FTy, Function::ExternalLinkage, "foo", NIM);

  BasicBlock *Top = BasicBlock::Create(Context, "top", F);
  BasicBlock *LPh = BasicBlock::Create(Context, "L.ph", F);
  BasicBlock *L = BasicBlock::Create(Context, "L", F);
  BasicBlock *Post = BasicBlock::Create(Context, "post", F);

  IRBuilder<> Builder(Top);
  Builder.CreateBr(LPh);

  Builder.SetInsertPoint(LPh);
  Builder.CreateBr(L);

  Builder.SetInsertPoint(L);
  PHINode *Phi = Builder.CreatePHI(T_int64, 2);
  auto *Add = cast<Instruction>(
      Builder.CreateAdd(Phi, ConstantInt::get(T_int64, 1), "add"));
  auto *Limit = ConstantInt::get(T_int64, 1000);
  auto *Cond = cast<Instruction>(
      Builder.CreateICmp(ICmpInst::ICMP_SLT, Add, Limit, "cond"));
  Builder.CreateCondBr(Cond, L, Post);
  Phi->addIncoming(ConstantInt::get(T_int64, 0), LPh);
  Phi->addIncoming(Add, L);

  Builder.SetInsertPoint(Post);
  Instruction *Ret = Builder.CreateRetVoid();

  ScalarEvolution SE = buildSE(*F);
  SCEVExpander Exp(SE, M.getDataLayout(), "expander");
  const SCEV *S = SE.getSCEV(Phi);
  EXPECT_TRUE(isa<SCEVAddRecExpr>(S));
  const SCEVAddRecExpr *AR = cast<SCEVAddRecExpr>(S);
  EXPECT_TRUE(AR->isAffine());
  EXPECT_FALSE(Exp.isSafeToExpandAt(AR, Top->getTerminator()));
  EXPECT_FALSE(Exp.isSafeToExpandAt(AR, LPh->getTerminator()));
  EXPECT_TRUE(Exp.isSafeToExpandAt(AR, L->getTerminator()));
  EXPECT_TRUE(Exp.isSafeToExpandAt(AR, Post->getTerminator()));
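  // The addrec is anchored to loop %L: only points dominated by the loop
  // header can compute it, so expansion in %top and %L.ph (before the loop)
  // is rejected, while %L and %post are fine.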
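  // Expanding %add's SCEV at the return uses a loop value outside the loop;
  // the expander must keep the loop in LCSSA form while doing so.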
  EXPECT_TRUE(LI->getLoopFor(L)->isLCSSAForm(*DT));
  Exp.expandCodeFor(SE.getSCEV(Add), nullptr, Ret);
  EXPECT_TRUE(LI->getLoopFor(L)->isLCSSAForm(*DT));
}

// Check that the SCEV expander does not reuse the existing add-nuw
// instruction when expanding the expression.
TEST_F(ScalarEvolutionExpanderTest, SCEVExpanderNUW) {
  /*
   * Create the following code:
   * func(i64 %a)
   * entry:
   *   br i1 false, label %exit, label %body
   * body:
   *  %s1 = add i64 %a, -1
   *  br label %exit
   * exit:
   *  %s2 = add nuw i64 %a, -1
   *  ret void
   */

  // Create a module.
  Module M("SCEVExpanderNUW", Context);

  Type *T_int64 = Type::getInt64Ty(Context);

  FunctionType *FTy =
      FunctionType::get(Type::getVoidTy(Context), {T_int64}, false);
  Function *F = Function::Create(FTy, Function::ExternalLinkage, "func", M);
  Argument *Arg = &*F->arg_begin();
  ConstantInt *C = ConstantInt::get(Context, APInt(64, -1));

  BasicBlock *Entry = BasicBlock::Create(Context, "entry", F);
  BasicBlock *Body = BasicBlock::Create(Context, "body", F);
  BasicBlock *Exit = BasicBlock::Create(Context, "exit", F);

  IRBuilder<> Builder(Entry);
  ConstantInt *Cond = ConstantInt::get(Context, APInt(1, 0));
  Builder.CreateCondBr(Cond, Exit, Body);

  Builder.SetInsertPoint(Body);
  auto *S1 = cast<Instruction>(Builder.CreateAdd(Arg, C, "add"));
  Builder.CreateBr(Exit);

  Builder.SetInsertPoint(Exit);
  auto *S2 = cast<Instruction>(Builder.CreateAdd(Arg, C, "add"));
  S2->setHasNoUnsignedWrap(true);
  auto *R = cast<Instruction>(Builder.CreateRetVoid());

  ScalarEvolution SE = buildSE(*F);
  const SCEV *S = SE.getSCEV(S1);
  EXPECT_TRUE(isa<SCEVAddExpr>(S));
  SCEVExpander Exp(SE, M.getDataLayout(), "expander");
  auto *I = cast<Instruction>(Exp.expandCodeFor(S, nullptr, R));
  EXPECT_FALSE(I->hasNoUnsignedWrap());
}

// Check that the SCEV expander does not reuse the existing add-nsw
// instruction when expanding the expression.
TEST_F(ScalarEvolutionExpanderTest, SCEVExpanderNSW) {
  /*
   * Create the following code:
   * func(i64 %a)
   * entry:
   *   br i1 false, label %exit, label %body
   * body:
   *  %s1 = add i64 %a, -1
   *  br label %exit
   * exit:
   *  %s2 = add nsw i64 %a, -1
   *  ret void
   */

  // Create a module.
  Module M("SCEVExpanderNSW", Context);

  Type *T_int64 = Type::getInt64Ty(Context);

  FunctionType *FTy =
      FunctionType::get(Type::getVoidTy(Context), {T_int64}, false);
  Function *F = Function::Create(FTy, Function::ExternalLinkage, "func", M);
  Argument *Arg = &*F->arg_begin();
  ConstantInt *C = ConstantInt::get(Context, APInt(64, -1));

  BasicBlock *Entry = BasicBlock::Create(Context, "entry", F);
  BasicBlock *Body = BasicBlock::Create(Context, "body", F);
  BasicBlock *Exit = BasicBlock::Create(Context, "exit", F);

  IRBuilder<> Builder(Entry);
  ConstantInt *Cond = ConstantInt::get(Context, APInt(1, 0));
  Builder.CreateCondBr(Cond, Exit, Body);

  Builder.SetInsertPoint(Body);
  auto *S1 = cast<Instruction>(Builder.CreateAdd(Arg, C, "add"));
  Builder.CreateBr(Exit);

  Builder.SetInsertPoint(Exit);
  auto *S2 = cast<Instruction>(Builder.CreateAdd(Arg, C, "add"));
  S2->setHasNoSignedWrap(true);
  auto *R = cast<Instruction>(Builder.CreateRetVoid());

  ScalarEvolution SE = buildSE(*F);
  const SCEV *S = SE.getSCEV(S1);
  EXPECT_TRUE(isa<SCEVAddExpr>(S));
  SCEVExpander Exp(SE, M.getDataLayout(), "expander");
  auto *I = cast<Instruction>(Exp.expandCodeFor(S, nullptr, R));
  EXPECT_FALSE(I->hasNoSignedWrap());
}

// Check that the expander does not reuse a cached SCEV -> V mapping when V
// differs from the SCEV in its NUW flag.
TEST_F(ScalarEvolutionExpanderTest, SCEVCacheNUW) {
  /*
   * Create the following code:
   * func(i64 %a)
   * entry:
   *  %s1 = add i64 %a, -1
   *  %s2 = add nuw i64 %a, -1
   *  br label %exit
   * exit:
   *  ret void
   */

  // Create a module.
  Module M("SCEVCacheNUW", Context);

  Type *T_int64 = Type::getInt64Ty(Context);

  FunctionType *FTy =
      FunctionType::get(Type::getVoidTy(Context), {T_int64}, false);
  Function *F = Function::Create(FTy, Function::ExternalLinkage, "func", M);
  Argument *Arg = &*F->arg_begin();
  ConstantInt *C = ConstantInt::get(Context, APInt(64, -1));

  BasicBlock *Entry = BasicBlock::Create(Context, "entry", F);
  BasicBlock *Exit = BasicBlock::Create(Context, "exit", F);

  IRBuilder<> Builder(Entry);
  auto *S1 = cast<Instruction>(Builder.CreateAdd(Arg, C, "add"));
  auto *S2 = cast<Instruction>(Builder.CreateAdd(Arg, C, "add"));
  S2->setHasNoUnsignedWrap(true);
  Builder.CreateBr(Exit);

  Builder.SetInsertPoint(Exit);
  auto *R = cast<Instruction>(Builder.CreateRetVoid());

  ScalarEvolution SE = buildSE(*F);
  // Get S2 first so it is the first entry in the cache.
  const SCEV *SC2 = SE.getSCEV(S2);
  EXPECT_TRUE(isa<SCEVAddExpr>(SC2));
  // Now get S1.
  const SCEV *SC1 = SE.getSCEV(S1);
  EXPECT_TRUE(isa<SCEVAddExpr>(SC1));
  // Expand SC1; the expansion should be based on S1, not S2, even though S2
  // was cached first.
  SCEVExpander Exp(SE, M.getDataLayout(), "expander");
  auto *I = cast<Instruction>(Exp.expandCodeFor(SC1, nullptr, R));
  EXPECT_FALSE(I->hasNoUnsignedWrap());
}

// Check that the expander does not reuse a cached SCEV -> V mapping when V
// differs from the SCEV in its NSW flag.
TEST_F(ScalarEvolutionExpanderTest, SCEVCacheNSW) {
  /*
   * Create the following code:
   * func(i64 %a)
   * entry:
   *  %s1 = add i64 %a, -1
   *  %s2 = add nsw i64 %a, -1
   *  br label %exit
   * exit:
   *  ret void
   */

  // Create a module.
  Module M("SCEVCacheNSW", Context);

  Type *T_int64 = Type::getInt64Ty(Context);

  FunctionType *FTy =
      FunctionType::get(Type::getVoidTy(Context), {T_int64}, false);
  Function *F = Function::Create(FTy, Function::ExternalLinkage, "func", M);
  Argument *Arg = &*F->arg_begin();
  ConstantInt *C = ConstantInt::get(Context, APInt(64, -1));

  BasicBlock *Entry = BasicBlock::Create(Context, "entry", F);
  BasicBlock *Exit = BasicBlock::Create(Context, "exit", F);

  IRBuilder<> Builder(Entry);
  auto *S1 = cast<Instruction>(Builder.CreateAdd(Arg, C, "add"));
  auto *S2 = cast<Instruction>(Builder.CreateAdd(Arg, C, "add"));
  S2->setHasNoSignedWrap(true);
  Builder.CreateBr(Exit);

  Builder.SetInsertPoint(Exit);
  auto *R = cast<Instruction>(Builder.CreateRetVoid());

  ScalarEvolution SE = buildSE(*F);
  // Get S2 first so it is the first entry in the cache.
  const SCEV *SC2 = SE.getSCEV(S2);
  EXPECT_TRUE(isa<SCEVAddExpr>(SC2));
  // Now get S1.
  const SCEV *SC1 = SE.getSCEV(S1);
  EXPECT_TRUE(isa<SCEVAddExpr>(SC1));
  // Expand SC1; the expansion should be based on S1, not S2, even though S2
  // was cached first.
  SCEVExpander Exp(SE, M.getDataLayout(), "expander");
  auto *I = cast<Instruction>(Exp.expandCodeFor(SC1, nullptr, R));
  EXPECT_FALSE(I->hasNoSignedWrap());
}

TEST_F(ScalarEvolutionExpanderTest, SCEVExpandInsertCanonicalIV) {
  LLVMContext C;
  SMDiagnostic Err;

  // Expand the addrec produced by GetAddRec into a loop without a canonical IV.
  // SCEVExpander will insert one.
  auto TestNoCanonicalIV =
      [&](std::function<const SCEV *(ScalarEvolution & SE, Loop * L)>
              GetAddRec) {
        std::unique_ptr<Module> M = parseAssemblyString(
            "define i32 @test(i32 %limit) { "
            "entry: "
            "  br label %loop "
            "loop: "
            "  %i = phi i32 [ 1, %entry ], [ %i.inc, %loop ] "
            "  %i.inc = add nsw i32 %i, 1 "
            "  %cont = icmp slt i32 %i.inc, %limit "
            "  br i1 %cont, label %loop, label %exit "
            "exit: "
            "  ret i32 %i.inc "
            "}",
            Err, C);

        assert(M && "Could not parse module?");
        assert(!verifyModule(*M) && "Must have been well formed!");

        runWithSE(
            *M, "test", [&](Function &F, LoopInfo &LI, ScalarEvolution &SE) {
              auto &I = GetInstByName(F, "i");
              auto *Loop = LI.getLoopFor(I.getParent());
              EXPECT_FALSE(Loop->getCanonicalInductionVariable());

              auto *AR = GetAddRec(SE, Loop);
              unsigned ExpectedCanonicalIVWidth =
                  SE.getTypeSizeInBits(AR->getType());

              SCEVExpander Exp(SE, M->getDataLayout(), "expander");
              auto *InsertAt = I.getNextNode();
              Exp.expandCodeFor(AR, nullptr, InsertAt);
              PHINode *CanonicalIV = Loop->getCanonicalInductionVariable();
              unsigned CanonicalIVBitWidth =
                  cast<IntegerType>(CanonicalIV->getType())->getBitWidth();
              EXPECT_EQ(CanonicalIVBitWidth, ExpectedCanonicalIVWidth);
            });
      };

  // Expand the addrec produced by GetAddRec into a loop with a canonical IV
  // which is narrower than the addrec type.
  // SCEVExpander will insert a canonical IV of a wider type to expand the
  // addrec.
  auto TestNarrowCanonicalIV = [&](std::function<const SCEV *(
                                       ScalarEvolution & SE, Loop * L)>
                                       GetAddRec) {
    std::unique_ptr<Module> M = parseAssemblyString(
        "define i32 @test(i32 %limit) { "
        "entry: "
        "  br label %loop "
        "loop: "
        "  %i = phi i32 [ 1, %entry ], [ %i.inc, %loop ] "
        "  %canonical.iv = phi i8 [ 0, %entry ], [ %canonical.iv.inc, %loop ] "
        "  %i.inc = add nsw i32 %i, 1 "
        "  %canonical.iv.inc = add i8 %canonical.iv, 1 "
        "  %cont = icmp slt i32 %i.inc, %limit "
        "  br i1 %cont, label %loop, label %exit "
        "exit: "
        "  ret i32 %i.inc "
        "}",
        Err, C);

    assert(M && "Could not parse module?");
    assert(!verifyModule(*M) && "Must have been well formed!");

    runWithSE(*M, "test", [&](Function &F, LoopInfo &LI, ScalarEvolution &SE) {
      auto &I = GetInstByName(F, "i");

      auto *LoopHeaderBB = I.getParent();
      auto *Loop = LI.getLoopFor(LoopHeaderBB);
      PHINode *CanonicalIV = Loop->getCanonicalInductionVariable();
      EXPECT_EQ(CanonicalIV, &GetInstByName(F, "canonical.iv"));

      auto *AR = GetAddRec(SE, Loop);

      unsigned ExpectedCanonicalIVWidth = SE.getTypeSizeInBits(AR->getType());
      unsigned CanonicalIVBitWidth =
          cast<IntegerType>(CanonicalIV->getType())->getBitWidth();
      EXPECT_LT(CanonicalIVBitWidth, ExpectedCanonicalIVWidth);

      SCEVExpander Exp(SE, M->getDataLayout(), "expander");
      auto *InsertAt = I.getNextNode();
      Exp.expandCodeFor(AR, nullptr, InsertAt);

      // Loop over all of the PHI nodes, looking for the new canonical indvar.
      PHINode *NewCanonicalIV = nullptr;
      for (BasicBlock::iterator i = LoopHeaderBB->begin(); isa<PHINode>(i);
           ++i) {
        PHINode *PN = cast<PHINode>(i);
        if (PN == &I || PN == CanonicalIV)
          continue;
        // We expect that the only PHI added is the new canonical IV
        EXPECT_FALSE(NewCanonicalIV);
        NewCanonicalIV = PN;
      }

      // Check that NewCanonicalIV is a canonical IV, i.e. {0,+,1}
      BasicBlock *Incoming = nullptr, *Backedge = nullptr;
      EXPECT_TRUE(Loop->getIncomingAndBackEdge(Incoming, Backedge));
      auto *Start = NewCanonicalIV->getIncomingValueForBlock(Incoming);
      EXPECT_TRUE(isa<ConstantInt>(Start));
      EXPECT_TRUE(dyn_cast<ConstantInt>(Start)->isZero());
      auto *Next = NewCanonicalIV->getIncomingValueForBlock(Backedge);
      EXPECT_TRUE(isa<BinaryOperator>(Next));
      auto *NextBinOp = dyn_cast<BinaryOperator>(Next);
      EXPECT_EQ(NextBinOp->getOpcode(), Instruction::Add);
      EXPECT_EQ(NextBinOp->getOperand(0), NewCanonicalIV);
      auto *Step = NextBinOp->getOperand(1);
      EXPECT_TRUE(isa<ConstantInt>(Step));
      EXPECT_TRUE(dyn_cast<ConstantInt>(Step)->isOne());

      unsigned NewCanonicalIVBitWidth =
          cast<IntegerType>(NewCanonicalIV->getType())->getBitWidth();
      EXPECT_EQ(NewCanonicalIVBitWidth, ExpectedCanonicalIVWidth);
    });
  };

  // Expand the addrec produced by GetAddRec into a loop with a canonical IV
  // of the same width as the addrec.
  // To expand the addrec, SCEVExpander should use the existing canonical IV.
  auto TestMatchingCanonicalIV =
      [&](std::function<const SCEV *(ScalarEvolution & SE, Loop * L)> GetAddRec,
          unsigned ARBitWidth) {
        auto ARBitWidthTypeStr = "i" + std::to_string(ARBitWidth);
        std::unique_ptr<Module> M = parseAssemblyString(
            "define i32 @test(i32 %limit) { "
            "entry: "
            "  br label %loop "
            "loop: "
            "  %i = phi i32 [ 1, %entry ], [ %i.inc, %loop ] "
            "  %canonical.iv = phi " +
                ARBitWidthTypeStr +
                " [ 0, %entry ], [ %canonical.iv.inc, %loop ] "
                "  %i.inc = add nsw i32 %i, 1 "
                "  %canonical.iv.inc = add " +
                ARBitWidthTypeStr +
                " %canonical.iv, 1 "
                "  %cont = icmp slt i32 %i.inc, %limit "
                "  br i1 %cont, label %loop, label %exit "
                "exit: "
                "  ret i32 %i.inc "
                "}",
            Err, C);

        assert(M && "Could not parse module?");
        assert(!verifyModule(*M) && "Must have been well formed!");

        runWithSE(
            *M, "test", [&](Function &F, LoopInfo &LI, ScalarEvolution &SE) {
              auto &I = GetInstByName(F, "i");
              auto &CanonicalIV = GetInstByName(F, "canonical.iv");

              auto *LoopHeaderBB = I.getParent();
              auto *Loop = LI.getLoopFor(LoopHeaderBB);
              EXPECT_EQ(&CanonicalIV, Loop->getCanonicalInductionVariable());
              unsigned CanonicalIVBitWidth =
                  cast<IntegerType>(CanonicalIV.getType())->getBitWidth();

              auto *AR = GetAddRec(SE, Loop);
              EXPECT_EQ(ARBitWidth, SE.getTypeSizeInBits(AR->getType()));
              EXPECT_EQ(CanonicalIVBitWidth, ARBitWidth);

              SCEVExpander Exp(SE, M->getDataLayout(), "expander");
              auto *InsertAt = I.getNextNode();
              Exp.expandCodeFor(AR, nullptr, InsertAt);

              // Loop over all of the PHI nodes, checking whether a new
              // canonical indvar was introduced.
              PHINode *NewCanonicalIV = nullptr;
              for (BasicBlock::iterator i = LoopHeaderBB->begin();
                   isa<PHINode>(i); ++i) {
                PHINode *PN = cast<PHINode>(i);
                if (PN == &I || PN == &CanonicalIV)
                  continue;
                NewCanonicalIV = PN;
              }
              EXPECT_FALSE(NewCanonicalIV);
            });
      };

  unsigned ARBitWidth = 16;
  Type *ARType = IntegerType::get(C, ARBitWidth);

  // Expand {5,+,1}
  auto GetAR2 = [&](ScalarEvolution &SE, Loop *L) -> const SCEV * {
    return SE.getAddRecExpr(SE.getConstant(APInt(ARBitWidth, 5)),
                            SE.getOne(ARType), L, SCEV::FlagAnyWrap);
  };
  TestNoCanonicalIV(GetAR2);
  TestNarrowCanonicalIV(GetAR2);
  TestMatchingCanonicalIV(GetAR2, ARBitWidth);
}

TEST_F(ScalarEvolutionExpanderTest, SCEVExpanderShlNSW) {

  auto checkOneCase = [this](std::string &&str) {
    LLVMContext C;
    SMDiagnostic Err;
    std::unique_ptr<Module> M = parseAssemblyString(str, Err, C);

    assert(M && "Could not parse module?");
    assert(!verifyModule(*M) && "Must have been well formed!");

    Function *F = M->getFunction("f");
    ASSERT_NE(F, nullptr) << "Could not find function 'f'";

    BasicBlock &Entry = F->getEntryBlock();
    LoadInst *Load = cast<LoadInst>(&Entry.front());
    BinaryOperator *And = cast<BinaryOperator>(*Load->user_begin());

    ScalarEvolution SE = buildSE(*F);
    const SCEV *AndSCEV = SE.getSCEV(And);
    EXPECT_TRUE(isa<SCEVMulExpr>(AndSCEV));
    EXPECT_TRUE(cast<SCEVMulExpr>(AndSCEV)->hasNoSignedWrap());

    SCEVExpander Exp(SE, M->getDataLayout(), "expander");
    auto *I = cast<Instruction>(Exp.expandCodeFor(AndSCEV, nullptr, And));
    EXPECT_EQ(I->getOpcode(), Instruction::Shl);
    EXPECT_FALSE(I->hasNoSignedWrap());
  };

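  // In both cases the mask keeps only the sign bit, so SCEV models the "and"
  // as a multiply by 2^(BW-1) carrying <nsw> (checked above); the expander
  // lowers that multiply to a shl, which must not inherit the nsw flag.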
  checkOneCase("define void @f(i16* %arrayidx) { "
               "  %1 = load i16, i16* %arrayidx "
               "  %2 = and i16 %1, -32768 "
               "  ret void "
               "} ");

  checkOneCase("define void @f(i8* %arrayidx) { "
               "  %1 = load i8, i8* %arrayidx "
               "  %2 = and i8 %1, -128 "
               "  ret void "
               "} ");
}

// Test expansion of nested addrecs in CanonicalMode.
// Expanding nested addrecs in canonical mode requires a canonical IV of a
// type wider than the type of the addrec itself. Currently, SCEVExpander
// just falls back to literal mode for nested addrecs.
TEST_F(ScalarEvolutionExpanderTest, SCEVExpandNonAffineAddRec) {
  LLVMContext C;
  SMDiagnostic Err;

  // Expand the addrec produced by GetAddRec into a loop without a canonical IV.
  auto TestNoCanonicalIV =
      [&](std::function<const SCEVAddRecExpr *(ScalarEvolution & SE, Loop * L)>
              GetAddRec) {
        std::unique_ptr<Module> M = parseAssemblyString(
            "define i32 @test(i32 %limit) { "
            "entry: "
            "  br label %loop "
            "loop: "
            "  %i = phi i32 [ 1, %entry ], [ %i.inc, %loop ] "
            "  %i.inc = add nsw i32 %i, 1 "
            "  %cont = icmp slt i32 %i.inc, %limit "
            "  br i1 %cont, label %loop, label %exit "
            "exit: "
            "  ret i32 %i.inc "
            "}",
            Err, C);

        assert(M && "Could not parse module?");
        assert(!verifyModule(*M) && "Must have been well formed!");

        runWithSE(*M, "test",
                  [&](Function &F, LoopInfo &LI, ScalarEvolution &SE) {
                    auto &I = GetInstByName(F, "i");
                    auto *Loop = LI.getLoopFor(I.getParent());
                    EXPECT_FALSE(Loop->getCanonicalInductionVariable());

                    auto *AR = GetAddRec(SE, Loop);
                    EXPECT_FALSE(AR->isAffine());

                    SCEVExpander Exp(SE, M->getDataLayout(), "expander");
                    auto *InsertAt = I.getNextNode();
                    Value *V = Exp.expandCodeFor(AR, nullptr, InsertAt);
                    const SCEV *ExpandedAR = SE.getSCEV(V);
                    // Check that the expansion happened literally.
                    EXPECT_EQ(AR, ExpandedAR);
                  });
      };

  // Expand the addrec produced by GetAddRec into a loop with a canonical IV
  // which is narrower than the addrec type.
  auto TestNarrowCanonicalIV = [&](std::function<const SCEVAddRecExpr *(
                                       ScalarEvolution & SE, Loop * L)>
                                       GetAddRec) {
    std::unique_ptr<Module> M = parseAssemblyString(
        "define i32 @test(i32 %limit) { "
        "entry: "
        "  br label %loop "
        "loop: "
        "  %i = phi i32 [ 1, %entry ], [ %i.inc, %loop ] "
        "  %canonical.iv = phi i8 [ 0, %entry ], [ %canonical.iv.inc, %loop ] "
        "  %i.inc = add nsw i32 %i, 1 "
        "  %canonical.iv.inc = add i8 %canonical.iv, 1 "
        "  %cont = icmp slt i32 %i.inc, %limit "
        "  br i1 %cont, label %loop, label %exit "
        "exit: "
        "  ret i32 %i.inc "
        "}",
        Err, C);

    assert(M && "Could not parse module?");
    assert(!verifyModule(*M) && "Must have been well formed!");

    runWithSE(*M, "test", [&](Function &F, LoopInfo &LI, ScalarEvolution &SE) {
      auto &I = GetInstByName(F, "i");

      auto *LoopHeaderBB = I.getParent();
      auto *Loop = LI.getLoopFor(LoopHeaderBB);
      PHINode *CanonicalIV = Loop->getCanonicalInductionVariable();
      EXPECT_EQ(CanonicalIV, &GetInstByName(F, "canonical.iv"));

      auto *AR = GetAddRec(SE, Loop);
      EXPECT_FALSE(AR->isAffine());

      unsigned ExpectedCanonicalIVWidth = SE.getTypeSizeInBits(AR->getType());
      unsigned CanonicalIVBitWidth =
          cast<IntegerType>(CanonicalIV->getType())->getBitWidth();
      EXPECT_LT(CanonicalIVBitWidth, ExpectedCanonicalIVWidth);

      SCEVExpander Exp(SE, M->getDataLayout(), "expander");
      auto *InsertAt = I.getNextNode();
      Value *V = Exp.expandCodeFor(AR, nullptr, InsertAt);
      const SCEV *ExpandedAR = SE.getSCEV(V);
      // Check that the expansion happened literally.
      EXPECT_EQ(AR, ExpandedAR);
    });
  };

  // Expand the addrec produced by GetAddRec into a loop with a canonical IV
  // of the same width as the addrec.
  auto TestMatchingCanonicalIV =
      [&](std::function<const SCEVAddRecExpr *(ScalarEvolution & SE, Loop * L)>
              GetAddRec,
          unsigned ARBitWidth) {
        auto ARBitWidthTypeStr = "i" + std::to_string(ARBitWidth);
        std::unique_ptr<Module> M = parseAssemblyString(
            "define i32 @test(i32 %limit) { "
            "entry: "
            "  br label %loop "
            "loop: "
            "  %i = phi i32 [ 1, %entry ], [ %i.inc, %loop ] "
            "  %canonical.iv = phi " +
                ARBitWidthTypeStr +
                " [ 0, %entry ], [ %canonical.iv.inc, %loop ] "
                "  %i.inc = add nsw i32 %i, 1 "
                "  %canonical.iv.inc = add " +
                ARBitWidthTypeStr +
                " %canonical.iv, 1 "
                "  %cont = icmp slt i32 %i.inc, %limit "
                "  br i1 %cont, label %loop, label %exit "
                "exit: "
                "  ret i32 %i.inc "
                "}",
            Err, C);

        assert(M && "Could not parse module?");
        assert(!verifyModule(*M) && "Must have been well formed!");

        runWithSE(
            *M, "test", [&](Function &F, LoopInfo &LI, ScalarEvolution &SE) {
              auto &I = GetInstByName(F, "i");
              auto &CanonicalIV = GetInstByName(F, "canonical.iv");

              auto *LoopHeaderBB = I.getParent();
              auto *Loop = LI.getLoopFor(LoopHeaderBB);
              EXPECT_EQ(&CanonicalIV, Loop->getCanonicalInductionVariable());
              unsigned CanonicalIVBitWidth =
                  cast<IntegerType>(CanonicalIV.getType())->getBitWidth();

              auto *AR = GetAddRec(SE, Loop);
              EXPECT_FALSE(AR->isAffine());
              EXPECT_EQ(ARBitWidth, SE.getTypeSizeInBits(AR->getType()));
              EXPECT_EQ(CanonicalIVBitWidth, ARBitWidth);

              SCEVExpander Exp(SE, M->getDataLayout(), "expander");
              auto *InsertAt = I.getNextNode();
              Value *V = Exp.expandCodeFor(AR, nullptr, InsertAt);
              const SCEV *ExpandedAR = SE.getSCEV(V);
              // Check that the expansion happened literally.
              EXPECT_EQ(AR, ExpandedAR);
            });
      };

  unsigned ARBitWidth = 16;
  Type *ARType = IntegerType::get(C, ARBitWidth);

  // Expand {5,+,1,+,1}
  auto GetAR3 = [&](ScalarEvolution &SE, Loop *L) -> const SCEVAddRecExpr * {
    SmallVector<const SCEV *, 3> Ops = {SE.getConstant(APInt(ARBitWidth, 5)),
                                        SE.getOne(ARType), SE.getOne(ARType)};
    return cast<SCEVAddRecExpr>(SE.getAddRecExpr(Ops, L, SCEV::FlagAnyWrap));
  };
  TestNoCanonicalIV(GetAR3);
  TestNarrowCanonicalIV(GetAR3);
  TestMatchingCanonicalIV(GetAR3, ARBitWidth);

  // Expand {5,+,1,+,1,+,1}
  auto GetAR4 = [&](ScalarEvolution &SE, Loop *L) -> const SCEVAddRecExpr * {
    SmallVector<const SCEV *, 4> Ops = {SE.getConstant(APInt(ARBitWidth, 5)),
                                        SE.getOne(ARType), SE.getOne(ARType),
                                        SE.getOne(ARType)};
    return cast<SCEVAddRecExpr>(SE.getAddRecExpr(Ops, L, SCEV::FlagAnyWrap));
  };
  TestNoCanonicalIV(GetAR4);
  TestNarrowCanonicalIV(GetAR4);
  TestMatchingCanonicalIV(GetAR4, ARBitWidth);

  // Expand {5,+,1,+,1,+,1,+,1}
  auto GetAR5 = [&](ScalarEvolution &SE, Loop *L) -> const SCEVAddRecExpr * {
    SmallVector<const SCEV *, 5> Ops = {SE.getConstant(APInt(ARBitWidth, 5)),
                                        SE.getOne(ARType), SE.getOne(ARType),
                                        SE.getOne(ARType), SE.getOne(ARType)};
    return cast<SCEVAddRecExpr>(SE.getAddRecExpr(Ops, L, SCEV::FlagAnyWrap));
  };
  TestNoCanonicalIV(GetAR5);
  TestNarrowCanonicalIV(GetAR5);
  TestMatchingCanonicalIV(GetAR5, ARBitWidth);
}

TEST_F(ScalarEvolutionExpanderTest, ExpandNonIntegralPtrWithNullBase) {
  LLVMContext C;
  SMDiagnostic Err;

  std::unique_ptr<Module> M =
      parseAssemblyString("target datalayout = "
                          "\"e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:"
                          "128-n8:16:32:64-S128-ni:1-p2:32:8:8:32-ni:2\""
                          "define ptr addrspace(1) @test(i64 %offset) { "
                          "  %ptr = getelementptr inbounds float, ptr "
                          "addrspace(1) null, i64 %offset"
                          "  ret ptr addrspace(1) %ptr"
                          "}",
                          Err, C);

  assert(M && "Could not parse module?");
  assert(!verifyModule(*M) && "Must have been well formed!");

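  // Address space 1 is non-integral ("ni:1" above), so the expander must add
  // the +1 with a GEP off the null base rather than through inttoptr.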
  runWithSE(*M, "test", [&](Function &F, LoopInfo &LI, ScalarEvolution &SE) {
    auto &I = GetInstByName(F, "ptr");
    auto PtrPlus1 =
        SE.getAddExpr(SE.getSCEV(&I), SE.getConstant(I.getType(), 1));
    SCEVExpander Exp(SE, M->getDataLayout(), "expander");

    Value *V = Exp.expandCodeFor(PtrPlus1, I.getType(), &I);
    I.replaceAllUsesWith(V);

    // Check that the expander created:
    // define ptr addrspace(1) @test(i64 %offset) {
    //   %1 = shl i64 %offset, 2
    //   %2 = add nuw nsw i64 %1, 1
    //   %uglygep = getelementptr i8, ptr addrspace(1) null, i64 %2
    //   %ptr = getelementptr inbounds float, ptr addrspace(1) null, i64 %offset
    //   ret ptr addrspace(1) %uglygep
    // }

    Value *Offset = &*F.arg_begin();
    auto *GEP = dyn_cast<GetElementPtrInst>(V);
    EXPECT_TRUE(GEP);
    EXPECT_TRUE(cast<Constant>(GEP->getPointerOperand())->isNullValue());
    EXPECT_EQ(GEP->getNumOperands(), 2U);
    EXPECT_TRUE(match(
        GEP->getOperand(1),
        m_Add(m_Shl(m_Specific(Offset), m_SpecificInt(2)), m_SpecificInt(1))));
    EXPECT_EQ(cast<PointerType>(GEP->getPointerOperand()->getType())
                  ->getAddressSpace(),
              cast<PointerType>(I.getType())->getAddressSpace());
    EXPECT_FALSE(verifyFunction(F, &errs()));
  });
}

} // end namespace llvm