xref: /llvm-project/llvm/unittests/Transforms/Utils/ScalarEvolutionExpanderTest.cpp (revision 2b1122eaec0d0b335feeac4adfaab2e3e04b7bd9)
1 //=== ScalarEvolutionExpanderTest.cpp - ScalarEvolutionExpander unit tests ===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "llvm/Transforms/Utils/ScalarEvolutionExpander.h"
10 #include "llvm/ADT/SmallVector.h"
11 #include "llvm/Analysis/AssumptionCache.h"
12 #include "llvm/Analysis/LoopInfo.h"
13 #include "llvm/Analysis/ScalarEvolutionExpressions.h"
14 #include "llvm/Analysis/TargetLibraryInfo.h"
15 #include "llvm/AsmParser/Parser.h"
16 #include "llvm/IR/Constants.h"
17 #include "llvm/IR/Dominators.h"
18 #include "llvm/IR/GlobalVariable.h"
19 #include "llvm/IR/IRBuilder.h"
20 #include "llvm/IR/InstIterator.h"
21 #include "llvm/IR/LLVMContext.h"
22 #include "llvm/IR/Module.h"
23 #include "llvm/IR/PatternMatch.h"
24 #include "llvm/IR/Verifier.h"
25 #include "llvm/Support/SourceMgr.h"
26 #include "gtest/gtest.h"
27 
28 namespace llvm {
29 
30 using namespace PatternMatch;
31 
32 // We use this fixture to ensure that we clean up ScalarEvolution before
33 // deleting the PassManager.
34 class ScalarEvolutionExpanderTest : public testing::Test {
35 protected:
36   LLVMContext Context;
37   Module M;
38   TargetLibraryInfoImpl TLII;
39   TargetLibraryInfo TLI;
40 
41   std::unique_ptr<AssumptionCache> AC;
42   std::unique_ptr<DominatorTree> DT;
43   std::unique_ptr<LoopInfo> LI;
44 
45   ScalarEvolutionExpanderTest() : M("", Context), TLII(), TLI(TLII) {}
46 
47   ScalarEvolution buildSE(Function &F) {
48     AC.reset(new AssumptionCache(F));
49     DT.reset(new DominatorTree(F));
50     LI.reset(new LoopInfo(*DT));
51     return ScalarEvolution(F, TLI, *AC, *DT, *LI);
52   }
53 
54   void runWithSE(
55       Module &M, StringRef FuncName,
56       function_ref<void(Function &F, LoopInfo &LI, ScalarEvolution &SE)> Test) {
57     auto *F = M.getFunction(FuncName);
58     ASSERT_NE(F, nullptr) << "Could not find " << FuncName;
59     ScalarEvolution SE = buildSE(*F);
60     Test(*F, *LI, SE);
61   }
62 };
63 
64 static Instruction &GetInstByName(Function &F, StringRef Name) {
65   for (auto &I : instructions(F))
66     if (I.getName() == Name)
67       return I;
68   llvm_unreachable("Could not find instructions!");
69 }
70 
71 TEST_F(ScalarEvolutionExpanderTest, ExpandPtrTypeSCEV) {
72   // It is to test the fix for PR30213. It exercises the branch in scev
73   // expansion when the value in ValueOffsetPair is a ptr and the offset
74   // is not divisible by the elem type size of value.
75   auto *I8Ty = Type::getInt8Ty(Context);
76   auto *I8PtrTy = PointerType::get(Context, 0);
77   auto *I32Ty = Type::getInt32Ty(Context);
78   auto *I32PtrTy = PointerType::get(Context, 0);
79   FunctionType *FTy =
80       FunctionType::get(Type::getVoidTy(Context), std::vector<Type *>(), false);
81   Function *F = Function::Create(FTy, Function::ExternalLinkage, "f", M);
82   BasicBlock *EntryBB = BasicBlock::Create(Context, "entry", F);
83   BasicBlock *LoopBB = BasicBlock::Create(Context, "loop", F);
84   BasicBlock *ExitBB = BasicBlock::Create(Context, "exit", F);
85   BranchInst::Create(LoopBB, EntryBB);
86   ReturnInst::Create(Context, nullptr, ExitBB);
87 
88   // loop:                            ; preds = %loop, %entry
89   //   %alloca = alloca i32
90   //   %gep0 = getelementptr i32, i32* %alloca, i32 1
91   //   %bitcast1 = bitcast i32* %gep0 to i8*
92   //   %gep1 = getelementptr i8, i8* %bitcast1, i32 1
93   //   %gep2 = getelementptr i8, i8* undef, i32 1
94   //   %cmp = icmp ult i8* undef, %bitcast1
95   //   %select = select i1 %cmp, i8* %gep1, i8* %gep2
96   //   %bitcast2 = bitcast i8* %select to i32*
97   //   br i1 undef, label %loop, label %exit
98 
99   const DataLayout &DL = F->getDataLayout();
100   BranchInst *Br = BranchInst::Create(
101       LoopBB, ExitBB, UndefValue::get(Type::getInt1Ty(Context)), LoopBB);
102   AllocaInst *Alloca = new AllocaInst(I32Ty, DL.getAllocaAddrSpace(), "alloca",
103                                       Br->getIterator());
104   ConstantInt *Ci32 = ConstantInt::get(Context, APInt(32, 1));
105   GetElementPtrInst *Gep0 =
106       GetElementPtrInst::Create(I32Ty, Alloca, Ci32, "gep0", Br->getIterator());
107   CastInst *CastA = CastInst::CreateBitOrPointerCast(Gep0, I8PtrTy, "bitcast1",
108                                                      Br->getIterator());
109   GetElementPtrInst *Gep1 =
110       GetElementPtrInst::Create(I8Ty, CastA, Ci32, "gep1", Br->getIterator());
111   GetElementPtrInst *Gep2 = GetElementPtrInst::Create(
112       I8Ty, UndefValue::get(I8PtrTy), Ci32, "gep2", Br->getIterator());
113   CmpInst *Cmp = CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_ULT,
114                                  UndefValue::get(I8PtrTy), CastA, "cmp",
115                                  Br->getIterator());
116   SelectInst *Sel =
117       SelectInst::Create(Cmp, Gep1, Gep2, "select", Br->getIterator());
118   CastInst *CastB = CastInst::CreateBitOrPointerCast(Sel, I32PtrTy, "bitcast2",
119                                                      Br->getIterator());
120 
121   ScalarEvolution SE = buildSE(*F);
122   const SCEV *S = SE.getSCEV(CastB);
123   EXPECT_TRUE(isa<SCEVUnknown>(S));
124 }
125 
126 // Make sure that SCEV doesn't introduce illegal ptrtoint/inttoptr instructions
127 TEST_F(ScalarEvolutionExpanderTest, SCEVZeroExtendExprNonIntegral) {
128   /*
129    * Create the following code:
130    * func(i64 addrspace(10)* %arg)
131    * top:
132    *  br label %L.ph
133    * L.ph:
134    *  %gepbase = getelementptr i64 addrspace(10)* %arg, i64 1
135    *  br label %L
136    * L:
137    *  %phi = phi i64 [i64 0, %L.ph], [ %add, %L2 ]
138    *  %add = add i64 %phi2, 1
139    *  br i1 undef, label %post, label %L2
140    * post:
141    *  #= %gep = getelementptr i64 addrspace(10)* %gepbase, i64 %add =#
142    *  ret void
143    *
144    * We will create the appropriate SCEV expression for %gep and expand it,
145    * then check that no inttoptr/ptrtoint instructions got inserted.
146    */
147 
148   // Create a module with non-integral pointers in it's datalayout
149   Module NIM("nonintegral", Context);
150   std::string DataLayout = M.getDataLayoutStr();
151   if (!DataLayout.empty())
152     DataLayout += "-";
153   DataLayout += "ni:10";
154   NIM.setDataLayout(DataLayout);
155 
156   Type *T_int1 = Type::getInt1Ty(Context);
157   Type *T_int64 = Type::getInt64Ty(Context);
158   Type *T_pint64 = PointerType::get(Context, 10);
159 
160   FunctionType *FTy =
161       FunctionType::get(Type::getVoidTy(Context), {T_pint64}, false);
162   Function *F = Function::Create(FTy, Function::ExternalLinkage, "foo", NIM);
163 
164   Argument *Arg = &*F->arg_begin();
165 
166   BasicBlock *Top = BasicBlock::Create(Context, "top", F);
167   BasicBlock *LPh = BasicBlock::Create(Context, "L.ph", F);
168   BasicBlock *L = BasicBlock::Create(Context, "L", F);
169   BasicBlock *Post = BasicBlock::Create(Context, "post", F);
170 
171   IRBuilder<> Builder(Top);
172   Builder.CreateBr(LPh);
173 
174   Builder.SetInsertPoint(LPh);
175   Value *GepBase =
176       Builder.CreateGEP(T_int64, Arg, ConstantInt::get(T_int64, 1));
177   Builder.CreateBr(L);
178 
179   Builder.SetInsertPoint(L);
180   PHINode *Phi = Builder.CreatePHI(T_int64, 2);
181   Value *Add = Builder.CreateAdd(Phi, ConstantInt::get(T_int64, 1), "add");
182   Builder.CreateCondBr(UndefValue::get(T_int1), L, Post);
183   Phi->addIncoming(ConstantInt::get(T_int64, 0), LPh);
184   Phi->addIncoming(Add, L);
185 
186   Builder.SetInsertPoint(Post);
187   Instruction *Ret = Builder.CreateRetVoid();
188 
189   ScalarEvolution SE = buildSE(*F);
190   const SCEV *AddRec =
191       SE.getAddRecExpr(SE.getUnknown(GepBase), SE.getConstant(T_int64, 1),
192                        LI->getLoopFor(L), SCEV::FlagNUW);
193 
194   SCEVExpander Exp(SE, NIM.getDataLayout(), "expander");
195   Exp.disableCanonicalMode();
196   Exp.expandCodeFor(AddRec, T_pint64, Ret);
197 
198   // Make sure none of the instructions inserted were inttoptr/ptrtoint.
199   // The verifier will check this.
200   EXPECT_FALSE(verifyFunction(*F, &errs()));
201 }
202 
203 // Check that we can correctly identify the points at which the SCEV of the
204 // AddRec can be expanded.
205 TEST_F(ScalarEvolutionExpanderTest, SCEVExpanderIsSafeToExpandAt) {
206   /*
207    * Create the following code:
208    * func(i64 addrspace(10)* %arg)
209    * top:
210    *  br label %L.ph
211    * L.ph:
212    *  br label %L
213    * L:
214    *  %phi = phi i64 [i64 0, %L.ph], [ %add, %L2 ]
215    *  %add = add i64 %phi2, 1
216    *  %cond = icmp slt i64 %add, 1000; then becomes 2000.
217    *  br i1 %cond, label %post, label %L2
218    * post:
219    *  ret void
220    *
221    */
222 
223   // Create a module with non-integral pointers in it's datalayout
224   Module NIM("nonintegral", Context);
225   std::string DataLayout = M.getDataLayoutStr();
226   if (!DataLayout.empty())
227     DataLayout += "-";
228   DataLayout += "ni:10";
229   NIM.setDataLayout(DataLayout);
230 
231   Type *T_int64 = Type::getInt64Ty(Context);
232   Type *T_pint64 = PointerType::get(Context, 10);
233 
234   FunctionType *FTy =
235       FunctionType::get(Type::getVoidTy(Context), {T_pint64}, false);
236   Function *F = Function::Create(FTy, Function::ExternalLinkage, "foo", NIM);
237 
238   BasicBlock *Top = BasicBlock::Create(Context, "top", F);
239   BasicBlock *LPh = BasicBlock::Create(Context, "L.ph", F);
240   BasicBlock *L = BasicBlock::Create(Context, "L", F);
241   BasicBlock *Post = BasicBlock::Create(Context, "post", F);
242 
243   IRBuilder<> Builder(Top);
244   Builder.CreateBr(LPh);
245 
246   Builder.SetInsertPoint(LPh);
247   Builder.CreateBr(L);
248 
249   Builder.SetInsertPoint(L);
250   PHINode *Phi = Builder.CreatePHI(T_int64, 2);
251   auto *Add = cast<Instruction>(
252       Builder.CreateAdd(Phi, ConstantInt::get(T_int64, 1), "add"));
253   auto *Limit = ConstantInt::get(T_int64, 1000);
254   auto *Cond = cast<Instruction>(
255       Builder.CreateICmp(ICmpInst::ICMP_SLT, Add, Limit, "cond"));
256   Builder.CreateCondBr(Cond, L, Post);
257   Phi->addIncoming(ConstantInt::get(T_int64, 0), LPh);
258   Phi->addIncoming(Add, L);
259 
260   Builder.SetInsertPoint(Post);
261   Instruction *Ret = Builder.CreateRetVoid();
262 
263   ScalarEvolution SE = buildSE(*F);
264   SCEVExpander Exp(SE, M.getDataLayout(), "expander");
265   const SCEV *S = SE.getSCEV(Phi);
266   EXPECT_TRUE(isa<SCEVAddRecExpr>(S));
267   const SCEVAddRecExpr *AR = cast<SCEVAddRecExpr>(S);
268   EXPECT_TRUE(AR->isAffine());
269   EXPECT_FALSE(Exp.isSafeToExpandAt(AR, Top->getTerminator()));
270   EXPECT_FALSE(Exp.isSafeToExpandAt(AR, LPh->getTerminator()));
271   EXPECT_TRUE(Exp.isSafeToExpandAt(AR, L->getTerminator()));
272   EXPECT_TRUE(Exp.isSafeToExpandAt(AR, Post->getTerminator()));
273 
274   EXPECT_TRUE(LI->getLoopFor(L)->isLCSSAForm(*DT));
275   Exp.expandCodeFor(SE.getSCEV(Add), nullptr, Ret);
276   EXPECT_TRUE(LI->getLoopFor(L)->isLCSSAForm(*DT));
277 }
278 
279 // Check that SCEV expander does not use the nuw instruction
280 // for expansion.
281 TEST_F(ScalarEvolutionExpanderTest, SCEVExpanderNUW) {
282   /*
283    * Create the following code:
284    * func(i64 %a)
285    * entry:
286    *   br false, label %exit, label %body
287    * body:
288    *  %s1 = add i64 %a, -1
289    *  br label %exit
290    * exit:
291    *  %s = add nuw i64 %a, -1
292    *  ret %s
293    */
294 
295   // Create a module.
296   Module M("SCEVExpanderNUW", Context);
297 
298   Type *T_int64 = Type::getInt64Ty(Context);
299 
300   FunctionType *FTy =
301       FunctionType::get(Type::getVoidTy(Context), {T_int64}, false);
302   Function *F = Function::Create(FTy, Function::ExternalLinkage, "func", M);
303   Argument *Arg = &*F->arg_begin();
304   ConstantInt *C = ConstantInt::get(Context, APInt(64, -1));
305 
306   BasicBlock *Entry = BasicBlock::Create(Context, "entry", F);
307   BasicBlock *Body = BasicBlock::Create(Context, "body", F);
308   BasicBlock *Exit = BasicBlock::Create(Context, "exit", F);
309 
310   IRBuilder<> Builder(Entry);
311   ConstantInt *Cond = ConstantInt::get(Context, APInt(1, 0));
312   Builder.CreateCondBr(Cond, Exit, Body);
313 
314   Builder.SetInsertPoint(Body);
315   auto *S1 = cast<Instruction>(Builder.CreateAdd(Arg, C, "add"));
316   Builder.CreateBr(Exit);
317 
318   Builder.SetInsertPoint(Exit);
319   auto *S2 = cast<Instruction>(Builder.CreateAdd(Arg, C, "add"));
320   S2->setHasNoUnsignedWrap(true);
321   auto *R = cast<Instruction>(Builder.CreateRetVoid());
322 
323   ScalarEvolution SE = buildSE(*F);
324   const SCEV *S = SE.getSCEV(S1);
325   EXPECT_TRUE(isa<SCEVAddExpr>(S));
326   SCEVExpander Exp(SE, M.getDataLayout(), "expander");
327   auto *I = cast<Instruction>(Exp.expandCodeFor(S, nullptr, R));
328   EXPECT_FALSE(I->hasNoUnsignedWrap());
329 }
330 
331 // Check that SCEV expander does not use the nsw instruction
332 // for expansion.
333 TEST_F(ScalarEvolutionExpanderTest, SCEVExpanderNSW) {
334   /*
335    * Create the following code:
336    * func(i64 %a)
337    * entry:
338    *   br false, label %exit, label %body
339    * body:
340    *  %s1 = add i64 %a, -1
341    *  br label %exit
342    * exit:
343    *  %s = add nsw i64 %a, -1
344    *  ret %s
345    */
346 
347   // Create a module.
348   Module M("SCEVExpanderNSW", Context);
349 
350   Type *T_int64 = Type::getInt64Ty(Context);
351 
352   FunctionType *FTy =
353       FunctionType::get(Type::getVoidTy(Context), {T_int64}, false);
354   Function *F = Function::Create(FTy, Function::ExternalLinkage, "func", M);
355   Argument *Arg = &*F->arg_begin();
356   ConstantInt *C = ConstantInt::get(Context, APInt(64, -1));
357 
358   BasicBlock *Entry = BasicBlock::Create(Context, "entry", F);
359   BasicBlock *Body = BasicBlock::Create(Context, "body", F);
360   BasicBlock *Exit = BasicBlock::Create(Context, "exit", F);
361 
362   IRBuilder<> Builder(Entry);
363   ConstantInt *Cond = ConstantInt::get(Context, APInt(1, 0));
364   Builder.CreateCondBr(Cond, Exit, Body);
365 
366   Builder.SetInsertPoint(Body);
367   auto *S1 = cast<Instruction>(Builder.CreateAdd(Arg, C, "add"));
368   Builder.CreateBr(Exit);
369 
370   Builder.SetInsertPoint(Exit);
371   auto *S2 = cast<Instruction>(Builder.CreateAdd(Arg, C, "add"));
372   S2->setHasNoSignedWrap(true);
373   auto *R = cast<Instruction>(Builder.CreateRetVoid());
374 
375   ScalarEvolution SE = buildSE(*F);
376   const SCEV *S = SE.getSCEV(S1);
377   EXPECT_TRUE(isa<SCEVAddExpr>(S));
378   SCEVExpander Exp(SE, M.getDataLayout(), "expander");
379   auto *I = cast<Instruction>(Exp.expandCodeFor(S, nullptr, R));
380   EXPECT_FALSE(I->hasNoSignedWrap());
381 }
382 
383 // Check that SCEV does not save the SCEV -> V
384 // mapping of SCEV differ from V in NUW flag.
385 TEST_F(ScalarEvolutionExpanderTest, SCEVCacheNUW) {
386   /*
387    * Create the following code:
388    * func(i64 %a)
389    * entry:
390    *  %s1 = add i64 %a, -1
391    *  %s2 = add nuw i64 %a, -1
392    *  br label %exit
393    * exit:
394    *  ret %s
395    */
396 
397   // Create a module.
398   Module M("SCEVCacheNUW", Context);
399 
400   Type *T_int64 = Type::getInt64Ty(Context);
401 
402   FunctionType *FTy =
403       FunctionType::get(Type::getVoidTy(Context), {T_int64}, false);
404   Function *F = Function::Create(FTy, Function::ExternalLinkage, "func", M);
405   Argument *Arg = &*F->arg_begin();
406   ConstantInt *C = ConstantInt::get(Context, APInt(64, -1));
407 
408   BasicBlock *Entry = BasicBlock::Create(Context, "entry", F);
409   BasicBlock *Exit = BasicBlock::Create(Context, "exit", F);
410 
411   IRBuilder<> Builder(Entry);
412   auto *S1 = cast<Instruction>(Builder.CreateAdd(Arg, C, "add"));
413   auto *S2 = cast<Instruction>(Builder.CreateAdd(Arg, C, "add"));
414   S2->setHasNoUnsignedWrap(true);
415   Builder.CreateBr(Exit);
416 
417   Builder.SetInsertPoint(Exit);
418   auto *R = cast<Instruction>(Builder.CreateRetVoid());
419 
420   ScalarEvolution SE = buildSE(*F);
421   // Get S2 first to move it to cache.
422   const SCEV *SC2 = SE.getSCEV(S2);
423   EXPECT_TRUE(isa<SCEVAddExpr>(SC2));
424   // Now get S1.
425   const SCEV *SC1 = SE.getSCEV(S1);
426   EXPECT_TRUE(isa<SCEVAddExpr>(SC1));
427   // Expand for S1, it should use S1 not S2 in spite S2
428   // first in the cache.
429   SCEVExpander Exp(SE, M.getDataLayout(), "expander");
430   auto *I = cast<Instruction>(Exp.expandCodeFor(SC1, nullptr, R));
431   EXPECT_FALSE(I->hasNoUnsignedWrap());
432 }
433 
434 // Check that SCEV does not save the SCEV -> V
435 // mapping of SCEV differ from V in NSW flag.
436 TEST_F(ScalarEvolutionExpanderTest, SCEVCacheNSW) {
437   /*
438    * Create the following code:
439    * func(i64 %a)
440    * entry:
441    *  %s1 = add i64 %a, -1
442    *  %s2 = add nsw i64 %a, -1
443    *  br label %exit
444    * exit:
445    *  ret %s
446    */
447 
448   // Create a module.
449   Module M("SCEVCacheNUW", Context);
450 
451   Type *T_int64 = Type::getInt64Ty(Context);
452 
453   FunctionType *FTy =
454       FunctionType::get(Type::getVoidTy(Context), {T_int64}, false);
455   Function *F = Function::Create(FTy, Function::ExternalLinkage, "func", M);
456   Argument *Arg = &*F->arg_begin();
457   ConstantInt *C = ConstantInt::get(Context, APInt(64, -1));
458 
459   BasicBlock *Entry = BasicBlock::Create(Context, "entry", F);
460   BasicBlock *Exit = BasicBlock::Create(Context, "exit", F);
461 
462   IRBuilder<> Builder(Entry);
463   auto *S1 = cast<Instruction>(Builder.CreateAdd(Arg, C, "add"));
464   auto *S2 = cast<Instruction>(Builder.CreateAdd(Arg, C, "add"));
465   S2->setHasNoSignedWrap(true);
466   Builder.CreateBr(Exit);
467 
468   Builder.SetInsertPoint(Exit);
469   auto *R = cast<Instruction>(Builder.CreateRetVoid());
470 
471   ScalarEvolution SE = buildSE(*F);
472   // Get S2 first to move it to cache.
473   const SCEV *SC2 = SE.getSCEV(S2);
474   EXPECT_TRUE(isa<SCEVAddExpr>(SC2));
475   // Now get S1.
476   const SCEV *SC1 = SE.getSCEV(S1);
477   EXPECT_TRUE(isa<SCEVAddExpr>(SC1));
478   // Expand for S1, it should use S1 not S2 in spite S2
479   // first in the cache.
480   SCEVExpander Exp(SE, M.getDataLayout(), "expander");
481   auto *I = cast<Instruction>(Exp.expandCodeFor(SC1, nullptr, R));
482   EXPECT_FALSE(I->hasNoSignedWrap());
483 }
484 
485 TEST_F(ScalarEvolutionExpanderTest, SCEVExpandInsertCanonicalIV) {
486   LLVMContext C;
487   SMDiagnostic Err;
488 
489   // Expand the addrec produced by GetAddRec into a loop without a canonical IV.
490   // SCEVExpander will insert one.
491   auto TestNoCanonicalIV =
492       [&](std::function<const SCEV *(ScalarEvolution & SE, Loop * L)>
493               GetAddRec) {
494         std::unique_ptr<Module> M = parseAssemblyString(
495             "define i32 @test(i32 %limit) { "
496             "entry: "
497             "  br label %loop "
498             "loop: "
499             "  %i = phi i32 [ 1, %entry ], [ %i.inc, %loop ] "
500             "  %i.inc = add nsw i32 %i, 1 "
501             "  %cont = icmp slt i32 %i.inc, %limit "
502             "  br i1 %cont, label %loop, label %exit "
503             "exit: "
504             "  ret i32 %i.inc "
505             "}",
506             Err, C);
507 
508         assert(M && "Could not parse module?");
509         assert(!verifyModule(*M) && "Must have been well formed!");
510 
511         runWithSE(
512             *M, "test", [&](Function &F, LoopInfo &LI, ScalarEvolution &SE) {
513               auto &I = GetInstByName(F, "i");
514               auto *Loop = LI.getLoopFor(I.getParent());
515               EXPECT_FALSE(Loop->getCanonicalInductionVariable());
516 
517               auto *AR = GetAddRec(SE, Loop);
518               unsigned ExpectedCanonicalIVWidth =
519                   SE.getTypeSizeInBits(AR->getType());
520 
521               SCEVExpander Exp(SE, M->getDataLayout(), "expander");
522               auto *InsertAt = I.getNextNode();
523               Exp.expandCodeFor(AR, nullptr, InsertAt);
524               PHINode *CanonicalIV = Loop->getCanonicalInductionVariable();
525               unsigned CanonicalIVBitWidth =
526                   cast<IntegerType>(CanonicalIV->getType())->getBitWidth();
527               EXPECT_EQ(CanonicalIVBitWidth, ExpectedCanonicalIVWidth);
528             });
529       };
530 
531   // Expand the addrec produced by GetAddRec into a loop with a canonical IV
532   // which is narrower than addrec type.
533   // SCEVExpander will insert a canonical IV of a wider type to expand the
534   // addrec.
535   auto TestNarrowCanonicalIV = [&](std::function<const SCEV *(
536                                        ScalarEvolution & SE, Loop * L)>
537                                        GetAddRec) {
538     std::unique_ptr<Module> M = parseAssemblyString(
539         "define i32 @test(i32 %limit) { "
540         "entry: "
541         "  br label %loop "
542         "loop: "
543         "  %i = phi i32 [ 1, %entry ], [ %i.inc, %loop ] "
544         "  %canonical.iv = phi i8 [ 0, %entry ], [ %canonical.iv.inc, %loop ] "
545         "  %i.inc = add nsw i32 %i, 1 "
546         "  %canonical.iv.inc = add i8 %canonical.iv, 1 "
547         "  %cont = icmp slt i32 %i.inc, %limit "
548         "  br i1 %cont, label %loop, label %exit "
549         "exit: "
550         "  ret i32 %i.inc "
551         "}",
552         Err, C);
553 
554     assert(M && "Could not parse module?");
555     assert(!verifyModule(*M) && "Must have been well formed!");
556 
557     runWithSE(*M, "test", [&](Function &F, LoopInfo &LI, ScalarEvolution &SE) {
558       auto &I = GetInstByName(F, "i");
559 
560       auto *LoopHeaderBB = I.getParent();
561       auto *Loop = LI.getLoopFor(LoopHeaderBB);
562       PHINode *CanonicalIV = Loop->getCanonicalInductionVariable();
563       EXPECT_EQ(CanonicalIV, &GetInstByName(F, "canonical.iv"));
564 
565       auto *AR = GetAddRec(SE, Loop);
566 
567       unsigned ExpectedCanonicalIVWidth = SE.getTypeSizeInBits(AR->getType());
568       unsigned CanonicalIVBitWidth =
569           cast<IntegerType>(CanonicalIV->getType())->getBitWidth();
570       EXPECT_LT(CanonicalIVBitWidth, ExpectedCanonicalIVWidth);
571 
572       SCEVExpander Exp(SE, M->getDataLayout(), "expander");
573       auto *InsertAt = I.getNextNode();
574       Exp.expandCodeFor(AR, nullptr, InsertAt);
575 
576       // Loop over all of the PHI nodes, looking for the new canonical indvar.
577       PHINode *NewCanonicalIV = nullptr;
578       for (BasicBlock::iterator i = LoopHeaderBB->begin(); isa<PHINode>(i);
579            ++i) {
580         PHINode *PN = cast<PHINode>(i);
581         if (PN == &I || PN == CanonicalIV)
582           continue;
583         // We expect that the only PHI added is the new canonical IV
584         EXPECT_FALSE(NewCanonicalIV);
585         NewCanonicalIV = PN;
586       }
587 
588       // Check that NewCanonicalIV is a canonical IV, i.e {0,+,1}
589       BasicBlock *Incoming = nullptr, *Backedge = nullptr;
590       EXPECT_TRUE(Loop->getIncomingAndBackEdge(Incoming, Backedge));
591       auto *Start = NewCanonicalIV->getIncomingValueForBlock(Incoming);
592       EXPECT_TRUE(isa<ConstantInt>(Start));
593       EXPECT_TRUE(dyn_cast<ConstantInt>(Start)->isZero());
594       auto *Next = NewCanonicalIV->getIncomingValueForBlock(Backedge);
595       EXPECT_TRUE(isa<BinaryOperator>(Next));
596       auto *NextBinOp = dyn_cast<BinaryOperator>(Next);
597       EXPECT_EQ(NextBinOp->getOpcode(), Instruction::Add);
598       EXPECT_EQ(NextBinOp->getOperand(0), NewCanonicalIV);
599       auto *Step = NextBinOp->getOperand(1);
600       EXPECT_TRUE(isa<ConstantInt>(Step));
601       EXPECT_TRUE(dyn_cast<ConstantInt>(Step)->isOne());
602 
603       unsigned NewCanonicalIVBitWidth =
604           cast<IntegerType>(NewCanonicalIV->getType())->getBitWidth();
605       EXPECT_EQ(NewCanonicalIVBitWidth, ExpectedCanonicalIVWidth);
606     });
607   };
608 
609   // Expand the addrec produced by GetAddRec into a loop with a canonical IV
610   // of addrec width.
611   // To expand the addrec SCEVExpander should use the existing canonical IV.
612   auto TestMatchingCanonicalIV =
613       [&](std::function<const SCEV *(ScalarEvolution & SE, Loop * L)> GetAddRec,
614           unsigned ARBitWidth) {
615         auto ARBitWidthTypeStr = "i" + std::to_string(ARBitWidth);
616         std::unique_ptr<Module> M = parseAssemblyString(
617             "define i32 @test(i32 %limit) { "
618             "entry: "
619             "  br label %loop "
620             "loop: "
621             "  %i = phi i32 [ 1, %entry ], [ %i.inc, %loop ] "
622             "  %canonical.iv = phi " +
623                 ARBitWidthTypeStr +
624                 " [ 0, %entry ], [ %canonical.iv.inc, %loop ] "
625                 "  %i.inc = add nsw i32 %i, 1 "
626                 "  %canonical.iv.inc = add " +
627                 ARBitWidthTypeStr +
628                 " %canonical.iv, 1 "
629                 "  %cont = icmp slt i32 %i.inc, %limit "
630                 "  br i1 %cont, label %loop, label %exit "
631                 "exit: "
632                 "  ret i32 %i.inc "
633                 "}",
634             Err, C);
635 
636         assert(M && "Could not parse module?");
637         assert(!verifyModule(*M) && "Must have been well formed!");
638 
639         runWithSE(
640             *M, "test", [&](Function &F, LoopInfo &LI, ScalarEvolution &SE) {
641               auto &I = GetInstByName(F, "i");
642               auto &CanonicalIV = GetInstByName(F, "canonical.iv");
643 
644               auto *LoopHeaderBB = I.getParent();
645               auto *Loop = LI.getLoopFor(LoopHeaderBB);
646               EXPECT_EQ(&CanonicalIV, Loop->getCanonicalInductionVariable());
647               unsigned CanonicalIVBitWidth =
648                   cast<IntegerType>(CanonicalIV.getType())->getBitWidth();
649 
650               auto *AR = GetAddRec(SE, Loop);
651               EXPECT_EQ(ARBitWidth, SE.getTypeSizeInBits(AR->getType()));
652               EXPECT_EQ(CanonicalIVBitWidth, ARBitWidth);
653 
654               SCEVExpander Exp(SE, M->getDataLayout(), "expander");
655               auto *InsertAt = I.getNextNode();
656               Exp.expandCodeFor(AR, nullptr, InsertAt);
657 
658               // Loop over all of the PHI nodes, looking if a new canonical
659               // indvar was introduced.
660               PHINode *NewCanonicalIV = nullptr;
661               for (BasicBlock::iterator i = LoopHeaderBB->begin();
662                    isa<PHINode>(i); ++i) {
663                 PHINode *PN = cast<PHINode>(i);
664                 if (PN == &I || PN == &CanonicalIV)
665                   continue;
666                 NewCanonicalIV = PN;
667               }
668               EXPECT_FALSE(NewCanonicalIV);
669             });
670       };
671 
672   unsigned ARBitWidth = 16;
673   Type *ARType = IntegerType::get(C, ARBitWidth);
674 
675   // Expand {5,+,1}
676   auto GetAR2 = [&](ScalarEvolution &SE, Loop *L) -> const SCEV * {
677     return SE.getAddRecExpr(SE.getConstant(APInt(ARBitWidth, 5)),
678                             SE.getOne(ARType), L, SCEV::FlagAnyWrap);
679   };
680   TestNoCanonicalIV(GetAR2);
681   TestNarrowCanonicalIV(GetAR2);
682   TestMatchingCanonicalIV(GetAR2, ARBitWidth);
683 }
684 
685 TEST_F(ScalarEvolutionExpanderTest, SCEVExpanderShlNSW) {
686 
687   auto checkOneCase = [this](std::string &&str) {
688     LLVMContext C;
689     SMDiagnostic Err;
690     std::unique_ptr<Module> M = parseAssemblyString(str, Err, C);
691 
692     assert(M && "Could not parse module?");
693     assert(!verifyModule(*M) && "Must have been well formed!");
694 
695     Function *F = M->getFunction("f");
696     ASSERT_NE(F, nullptr) << "Could not find function 'f'";
697 
698     BasicBlock &Entry = F->getEntryBlock();
699     LoadInst *Load = cast<LoadInst>(&Entry.front());
700     BinaryOperator *And = cast<BinaryOperator>(*Load->user_begin());
701 
702     ScalarEvolution SE = buildSE(*F);
703     const SCEV *AndSCEV = SE.getSCEV(And);
704     EXPECT_TRUE(isa<SCEVMulExpr>(AndSCEV));
705     EXPECT_TRUE(cast<SCEVMulExpr>(AndSCEV)->hasNoSignedWrap());
706 
707     SCEVExpander Exp(SE, M->getDataLayout(), "expander");
708     auto *I = cast<Instruction>(Exp.expandCodeFor(AndSCEV, nullptr, And));
709     EXPECT_EQ(I->getOpcode(), Instruction::Shl);
710     EXPECT_FALSE(I->hasNoSignedWrap());
711   };
712 
713   checkOneCase("define void @f(i16* %arrayidx) { "
714                "  %1 = load i16, i16* %arrayidx "
715                "  %2 = and i16 %1, -32768 "
716                "  ret void "
717                "} ");
718 
719   checkOneCase("define void @f(i8* %arrayidx) { "
720                "  %1 = load i8, i8* %arrayidx "
721                "  %2 = and i8 %1, -128 "
722                "  ret void "
723                "} ");
724 }
725 
726 // Test expansion of nested addrecs in CanonicalMode.
727 // Expanding nested addrecs in canonical mode requiers a canonical IV of a
728 // type wider than the type of the addrec itself. Currently, SCEVExpander
729 // just falls back to literal mode for nested addrecs.
730 TEST_F(ScalarEvolutionExpanderTest, SCEVExpandNonAffineAddRec) {
731   LLVMContext C;
732   SMDiagnostic Err;
733 
734   // Expand the addrec produced by GetAddRec into a loop without a canonical IV.
735   auto TestNoCanonicalIV =
736       [&](std::function<const SCEVAddRecExpr *(ScalarEvolution & SE, Loop * L)>
737               GetAddRec) {
738         std::unique_ptr<Module> M = parseAssemblyString(
739             "define i32 @test(i32 %limit) { "
740             "entry: "
741             "  br label %loop "
742             "loop: "
743             "  %i = phi i32 [ 1, %entry ], [ %i.inc, %loop ] "
744             "  %i.inc = add nsw i32 %i, 1 "
745             "  %cont = icmp slt i32 %i.inc, %limit "
746             "  br i1 %cont, label %loop, label %exit "
747             "exit: "
748             "  ret i32 %i.inc "
749             "}",
750             Err, C);
751 
752         assert(M && "Could not parse module?");
753         assert(!verifyModule(*M) && "Must have been well formed!");
754 
755         runWithSE(*M, "test",
756                   [&](Function &F, LoopInfo &LI, ScalarEvolution &SE) {
757                     auto &I = GetInstByName(F, "i");
758                     auto *Loop = LI.getLoopFor(I.getParent());
759                     EXPECT_FALSE(Loop->getCanonicalInductionVariable());
760 
761                     auto *AR = GetAddRec(SE, Loop);
762                     EXPECT_FALSE(AR->isAffine());
763 
764                     SCEVExpander Exp(SE, M->getDataLayout(), "expander");
765                     auto *InsertAt = I.getNextNode();
766                     Value *V = Exp.expandCodeFor(AR, nullptr, InsertAt);
767                     const SCEV *ExpandedAR = SE.getSCEV(V);
768                     // Check that the expansion happened literally.
769                     EXPECT_EQ(AR, ExpandedAR);
770                   });
771       };
772 
773   // Expand the addrec produced by GetAddRec into a loop with a canonical IV
774   // which is narrower than addrec type.
775   auto TestNarrowCanonicalIV = [&](std::function<const SCEVAddRecExpr *(
776                                        ScalarEvolution & SE, Loop * L)>
777                                        GetAddRec) {
778     std::unique_ptr<Module> M = parseAssemblyString(
779         "define i32 @test(i32 %limit) { "
780         "entry: "
781         "  br label %loop "
782         "loop: "
783         "  %i = phi i32 [ 1, %entry ], [ %i.inc, %loop ] "
784         "  %canonical.iv = phi i8 [ 0, %entry ], [ %canonical.iv.inc, %loop ] "
785         "  %i.inc = add nsw i32 %i, 1 "
786         "  %canonical.iv.inc = add i8 %canonical.iv, 1 "
787         "  %cont = icmp slt i32 %i.inc, %limit "
788         "  br i1 %cont, label %loop, label %exit "
789         "exit: "
790         "  ret i32 %i.inc "
791         "}",
792         Err, C);
793 
794     assert(M && "Could not parse module?");
795     assert(!verifyModule(*M) && "Must have been well formed!");
796 
797     runWithSE(*M, "test", [&](Function &F, LoopInfo &LI, ScalarEvolution &SE) {
798       auto &I = GetInstByName(F, "i");
799 
800       auto *LoopHeaderBB = I.getParent();
801       auto *Loop = LI.getLoopFor(LoopHeaderBB);
802       PHINode *CanonicalIV = Loop->getCanonicalInductionVariable();
803       EXPECT_EQ(CanonicalIV, &GetInstByName(F, "canonical.iv"));
804 
805       auto *AR = GetAddRec(SE, Loop);
806       EXPECT_FALSE(AR->isAffine());
807 
808       unsigned ExpectedCanonicalIVWidth = SE.getTypeSizeInBits(AR->getType());
809       unsigned CanonicalIVBitWidth =
810           cast<IntegerType>(CanonicalIV->getType())->getBitWidth();
811       EXPECT_LT(CanonicalIVBitWidth, ExpectedCanonicalIVWidth);
812 
813       SCEVExpander Exp(SE, M->getDataLayout(), "expander");
814       auto *InsertAt = I.getNextNode();
815       Value *V = Exp.expandCodeFor(AR, nullptr, InsertAt);
816       const SCEV *ExpandedAR = SE.getSCEV(V);
817       // Check that the expansion happened literally.
818       EXPECT_EQ(AR, ExpandedAR);
819     });
820   };
821 
822   // Expand the addrec produced by GetAddRec into a loop with a canonical IV
823   // of addrec width.
824   auto TestMatchingCanonicalIV =
825       [&](std::function<const SCEVAddRecExpr *(ScalarEvolution & SE, Loop * L)>
826               GetAddRec,
827           unsigned ARBitWidth) {
828         auto ARBitWidthTypeStr = "i" + std::to_string(ARBitWidth);
829         std::unique_ptr<Module> M = parseAssemblyString(
830             "define i32 @test(i32 %limit) { "
831             "entry: "
832             "  br label %loop "
833             "loop: "
834             "  %i = phi i32 [ 1, %entry ], [ %i.inc, %loop ] "
835             "  %canonical.iv = phi " +
836                 ARBitWidthTypeStr +
837                 " [ 0, %entry ], [ %canonical.iv.inc, %loop ] "
838                 "  %i.inc = add nsw i32 %i, 1 "
839                 "  %canonical.iv.inc = add " +
840                 ARBitWidthTypeStr +
841                 " %canonical.iv, 1 "
842                 "  %cont = icmp slt i32 %i.inc, %limit "
843                 "  br i1 %cont, label %loop, label %exit "
844                 "exit: "
845                 "  ret i32 %i.inc "
846                 "}",
847             Err, C);
848 
849         assert(M && "Could not parse module?");
850         assert(!verifyModule(*M) && "Must have been well formed!");
851 
852         runWithSE(
853             *M, "test", [&](Function &F, LoopInfo &LI, ScalarEvolution &SE) {
854               auto &I = GetInstByName(F, "i");
855               auto &CanonicalIV = GetInstByName(F, "canonical.iv");
856 
857               auto *LoopHeaderBB = I.getParent();
858               auto *Loop = LI.getLoopFor(LoopHeaderBB);
859               EXPECT_EQ(&CanonicalIV, Loop->getCanonicalInductionVariable());
860               unsigned CanonicalIVBitWidth =
861                   cast<IntegerType>(CanonicalIV.getType())->getBitWidth();
862 
863               auto *AR = GetAddRec(SE, Loop);
864               EXPECT_FALSE(AR->isAffine());
865               EXPECT_EQ(ARBitWidth, SE.getTypeSizeInBits(AR->getType()));
866               EXPECT_EQ(CanonicalIVBitWidth, ARBitWidth);
867 
868               SCEVExpander Exp(SE, M->getDataLayout(), "expander");
869               auto *InsertAt = I.getNextNode();
870               Value *V = Exp.expandCodeFor(AR, nullptr, InsertAt);
871               const SCEV *ExpandedAR = SE.getSCEV(V);
872               // Check that the expansion happened literally.
873               EXPECT_EQ(AR, ExpandedAR);
874             });
875       };
876 
877   unsigned ARBitWidth = 16;
878   Type *ARType = IntegerType::get(C, ARBitWidth);
879 
880   // Expand {5,+,1,+,1}
881   auto GetAR3 = [&](ScalarEvolution &SE, Loop *L) -> const SCEVAddRecExpr * {
882     SmallVector<const SCEV *, 3> Ops = {SE.getConstant(APInt(ARBitWidth, 5)),
883                                         SE.getOne(ARType), SE.getOne(ARType)};
884     return cast<SCEVAddRecExpr>(SE.getAddRecExpr(Ops, L, SCEV::FlagAnyWrap));
885   };
886   TestNoCanonicalIV(GetAR3);
887   TestNarrowCanonicalIV(GetAR3);
888   TestMatchingCanonicalIV(GetAR3, ARBitWidth);
889 
890   // Expand {5,+,1,+,1,+,1}
891   auto GetAR4 = [&](ScalarEvolution &SE, Loop *L) -> const SCEVAddRecExpr * {
892     SmallVector<const SCEV *, 4> Ops = {SE.getConstant(APInt(ARBitWidth, 5)),
893                                         SE.getOne(ARType), SE.getOne(ARType),
894                                         SE.getOne(ARType)};
895     return cast<SCEVAddRecExpr>(SE.getAddRecExpr(Ops, L, SCEV::FlagAnyWrap));
896   };
897   TestNoCanonicalIV(GetAR4);
898   TestNarrowCanonicalIV(GetAR4);
899   TestMatchingCanonicalIV(GetAR4, ARBitWidth);
900 
901   // Expand {5,+,1,+,1,+,1,+,1}
902   auto GetAR5 = [&](ScalarEvolution &SE, Loop *L) -> const SCEVAddRecExpr * {
903     SmallVector<const SCEV *, 5> Ops = {SE.getConstant(APInt(ARBitWidth, 5)),
904                                         SE.getOne(ARType), SE.getOne(ARType),
905                                         SE.getOne(ARType), SE.getOne(ARType)};
906     return cast<SCEVAddRecExpr>(SE.getAddRecExpr(Ops, L, SCEV::FlagAnyWrap));
907   };
908   TestNoCanonicalIV(GetAR5);
909   TestNarrowCanonicalIV(GetAR5);
910   TestMatchingCanonicalIV(GetAR5, ARBitWidth);
911 }
912 
913 TEST_F(ScalarEvolutionExpanderTest, ExpandNonIntegralPtrWithNullBase) {
914   LLVMContext C;
915   SMDiagnostic Err;
916 
917   std::unique_ptr<Module> M =
918       parseAssemblyString("target datalayout = "
919                           "\"e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:"
920                           "128-n8:16:32:64-S128-ni:1-p2:32:8:8:32-ni:2\""
921                           "define ptr addrspace(1) @test(i64 %offset) { "
922                           "  %ptr = getelementptr inbounds float, ptr "
923                           "addrspace(1) null, i64 %offset"
924                           "  ret ptr addrspace(1) %ptr"
925                           "}",
926                           Err, C);
927 
928   assert(M && "Could not parse module?");
929   assert(!verifyModule(*M) && "Must have been well formed!");
930 
931   runWithSE(*M, "test", [&](Function &F, LoopInfo &LI, ScalarEvolution &SE) {
932     auto &I = GetInstByName(F, "ptr");
933     auto PtrPlus1 =
934         SE.getAddExpr(SE.getSCEV(&I), SE.getConstant(I.getType(), 1));
935     SCEVExpander Exp(SE, M->getDataLayout(), "expander");
936 
937     Value *V = Exp.expandCodeFor(PtrPlus1, I.getType(), &I);
938     I.replaceAllUsesWith(V);
939 
940     // Check that the expander created:
941     // define ptr addrspace(1) @test(i64 %off) {
942     //   %1 = shl i64 %offset, 2
943     //   %2 = add nuw nsw i64 %1, 1
944     //   %uglygep = getelementptr i8, ptr addrspace(1) null, i64 %2
945     //   %ptr = getelementptr inbounds float, ptr addrspace(1) null, i64 %off
946     //   ret ptr addrspace(1) %uglygep
947     // }
948 
949     Value *Offset = &*F.arg_begin();
950     auto *GEP = dyn_cast<GetElementPtrInst>(V);
951     EXPECT_TRUE(GEP);
952     EXPECT_TRUE(cast<Constant>(GEP->getPointerOperand())->isNullValue());
953     EXPECT_EQ(GEP->getNumOperands(), 2U);
954     EXPECT_TRUE(match(
955         GEP->getOperand(1),
956         m_Add(m_Shl(m_Specific(Offset), m_SpecificInt(2)), m_SpecificInt(1))));
957     EXPECT_EQ(cast<PointerType>(GEP->getPointerOperand()->getType())
958                   ->getAddressSpace(),
959               cast<PointerType>(I.getType())->getAddressSpace());
960     EXPECT_FALSE(verifyFunction(F, &errs()));
961   });
962 }
963 
964 } // end namespace llvm
965