xref: /llvm-project/llvm/lib/Analysis/ScalarEvolutionNormalization.cpp (revision 369f3039a3b5efb758e339a2452313f681cfc789)
1 //===- ScalarEvolutionNormalization.cpp - See below -----------------------===//
2 //
3 //                     The LLVM Compiler Infrastructure
4 //
5 // This file is distributed under the University of Illinois Open Source
6 // License. See LICENSE.TXT for details.
7 //
8 //===----------------------------------------------------------------------===//
9 //
10 // This file implements utilities for working with "normalized" expressions.
11 // See the comments at the top of ScalarEvolutionNormalization.h for details.
12 //
13 //===----------------------------------------------------------------------===//
14 
15 #include "llvm/Analysis/LoopInfo.h"
16 #include "llvm/Analysis/ScalarEvolutionExpressions.h"
17 #include "llvm/Analysis/ScalarEvolutionNormalization.h"
18 using namespace llvm;
19 
/// Selects which direction the post-increment rewrite in this file runs.
enum TransformKind {
  /// Rewrite the expression into its normalized form with respect to the
  /// selected loops.
  Normalize,
  /// Undo a previous normalization, i.e. apply the inverse rewrite for the
  /// selected loops.
  Denormalize
};
29 
30 typedef DenseMap<const SCEV *, const SCEV *> NormalizedCacheTy;
31 
32 static const SCEV *transformSubExpr(const TransformKind Kind,
33                                     NormalizePredTy Pred, ScalarEvolution &SE,
34                                     NormalizedCacheTy &Cache, const SCEV *S);
35 
36 /// Implement post-inc transformation for all valid expression types.
37 static const SCEV *transformImpl(const TransformKind Kind, NormalizePredTy Pred,
38                                  ScalarEvolution &SE, NormalizedCacheTy &Cache,
39                                  const SCEV *S) {
40   if (const SCEVCastExpr *X = dyn_cast<SCEVCastExpr>(S)) {
41     const SCEV *O = X->getOperand();
42     const SCEV *N = transformSubExpr(Kind, Pred, SE, Cache, O);
43     if (O != N)
44       switch (S->getSCEVType()) {
45       case scZeroExtend: return SE.getZeroExtendExpr(N, S->getType());
46       case scSignExtend: return SE.getSignExtendExpr(N, S->getType());
47       case scTruncate: return SE.getTruncateExpr(N, S->getType());
48       default: llvm_unreachable("Unexpected SCEVCastExpr kind!");
49       }
50     return S;
51   }
52 
53   if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(S)) {
54     // An addrec. This is the interesting part.
55     SmallVector<const SCEV *, 8> Operands;
56 
57     transform(AR->operands(), std::back_inserter(Operands),
58               [&](const SCEV *Op) {
59                 return transformSubExpr(Kind, Pred, SE, Cache, Op);
60               });
61 
62     // Conservatively use AnyWrap until/unless we need FlagNW.
63     const SCEV *Result =
64         SE.getAddRecExpr(Operands, AR->getLoop(), SCEV::FlagAnyWrap);
65     switch (Kind) {
66     case Normalize:
67       // We want to normalize step expression, because otherwise we might not be
68       // able to denormalize to the original expression.
69       //
70       // Here is an example what will happen if we don't normalize step:
71       //  ORIGINAL ISE:
72       //    {(100 /u {1,+,1}<%bb16>),+,(100 /u {1,+,1}<%bb16>)}<%bb25>
73       //  NORMALIZED ISE:
74       //    {((-1 * (100 /u {1,+,1}<%bb16>)) + (100 /u {0,+,1}<%bb16>)),+,
75       //     (100 /u {0,+,1}<%bb16>)}<%bb25>
76       //  DENORMALIZED BACK ISE:
77       //    {((2 * (100 /u {1,+,1}<%bb16>)) + (-1 * (100 /u {2,+,1}<%bb16>))),+,
78       //     (100 /u {1,+,1}<%bb16>)}<%bb25>
79       //  Note that the initial value changes after normalization +
80       //  denormalization, which isn't correct.
81       if (Pred(AR)) {
82         const SCEV *TransformedStep =
83             transformSubExpr(Kind, Pred, SE, Cache, AR->getStepRecurrence(SE));
84         Result = SE.getMinusSCEV(Result, TransformedStep);
85       }
86 #if 0
87       // See the comment on the assert above.
88       assert(S == transformSubExpr(Result, User, OperandValToReplace) &&
89              "SCEV normalization is not invertible!");
90 #endif
91       break;
92     case Denormalize:
93       // Here we want to normalize step expressions for the same reasons, as
94       // stated above.
95       if (Pred(AR)) {
96         const SCEV *TransformedStep =
97             transformSubExpr(Kind, Pred, SE, Cache, AR->getStepRecurrence(SE));
98         Result = SE.getAddExpr(Result, TransformedStep);
99       }
100       break;
101     }
102     return Result;
103   }
104 
105   if (const SCEVNAryExpr *X = dyn_cast<SCEVNAryExpr>(S)) {
106     SmallVector<const SCEV *, 8> Operands;
107     bool Changed = false;
108     // Transform each operand.
109     for (auto *O : X->operands()) {
110       const SCEV *N = transformSubExpr(Kind, Pred, SE, Cache, O);
111       Changed |= N != O;
112       Operands.push_back(N);
113     }
114     // If any operand actually changed, return a transformed result.
115     if (Changed)
116       switch (S->getSCEVType()) {
117       case scAddExpr: return SE.getAddExpr(Operands);
118       case scMulExpr: return SE.getMulExpr(Operands);
119       case scSMaxExpr: return SE.getSMaxExpr(Operands);
120       case scUMaxExpr: return SE.getUMaxExpr(Operands);
121       default: llvm_unreachable("Unexpected SCEVNAryExpr kind!");
122       }
123     return S;
124   }
125 
126   if (const SCEVUDivExpr *X = dyn_cast<SCEVUDivExpr>(S)) {
127     const SCEV *LO = X->getLHS();
128     const SCEV *RO = X->getRHS();
129     const SCEV *LN = transformSubExpr(Kind, Pred, SE, Cache, LO);
130     const SCEV *RN = transformSubExpr(Kind, Pred, SE, Cache, RO);
131     if (LO != LN || RO != RN)
132       return SE.getUDivExpr(LN, RN);
133     return S;
134   }
135 
136   llvm_unreachable("Unexpected SCEV kind!");
137 }
138 
139 /// Manage recursive transformation across an expression DAG. Revisiting
140 /// expressions would lead to exponential recursion.
141 static const SCEV *transformSubExpr(const TransformKind Kind,
142                                     NormalizePredTy Pred, ScalarEvolution &SE,
143                                     NormalizedCacheTy &Cache, const SCEV *S) {
144   if (isa<SCEVConstant>(S) || isa<SCEVUnknown>(S))
145     return S;
146 
147   const SCEV *Result = Cache.lookup(S);
148   if (Result)
149     return Result;
150 
151   Result = transformImpl(Kind, Pred, SE, Cache, S);
152   Cache[S] = Result;
153   return Result;
154 }
155 
156 const SCEV *llvm::normalizeForPostIncUse(const SCEV *S,
157                                          const PostIncLoopSet &Loops,
158                                          ScalarEvolution &SE) {
159   auto Pred = [&](const SCEVAddRecExpr *AR) {
160     return Loops.count(AR->getLoop());
161   };
162   NormalizedCacheTy Cache;
163   return transformSubExpr(Normalize, Pred, SE, Cache, S);
164 }
165 
166 const SCEV *llvm::normalizeForPostIncUseIf(const SCEV *S, NormalizePredTy Pred,
167                                            ScalarEvolution &SE) {
168   NormalizedCacheTy Cache;
169   return transformSubExpr(Normalize, Pred, SE, Cache, S);
170 }
171 
172 const SCEV *llvm::denormalizeForPostIncUse(const SCEV *S,
173                                            const PostIncLoopSet &Loops,
174                                            ScalarEvolution &SE) {
175   auto Pred = [&](const SCEVAddRecExpr *AR) {
176     return Loops.count(AR->getLoop());
177   };
178   NormalizedCacheTy Cache;
179   return transformSubExpr(Denormalize, Pred, SE, Cache, S);
180 }
181