xref: /llvm-project/llvm/lib/Analysis/ScalarEvolutionNormalization.cpp (revision 478cd98b22cd1e645cd15a6475773b8cf5a857ee)
1 //===- ScalarEvolutionNormalization.cpp - See below -----------------------===//
2 //
3 //                     The LLVM Compiler Infrastructure
4 //
5 // This file is distributed under the University of Illinois Open Source
6 // License. See LICENSE.TXT for details.
7 //
8 //===----------------------------------------------------------------------===//
9 //
10 // This file implements utilities for working with "normalized" expressions.
11 // See the comments at the top of ScalarEvolutionNormalization.h for details.
12 //
13 //===----------------------------------------------------------------------===//
14 
15 #include "llvm/Analysis/LoopInfo.h"
16 #include "llvm/Analysis/ScalarEvolutionExpressions.h"
17 #include "llvm/Analysis/ScalarEvolutionNormalization.h"
18 using namespace llvm;
19 
20 namespace {
21 
22 /// TransformKind - Different types of transformations that
23 /// TransformForPostIncUse can do.
24 enum TransformKind {
25   /// Normalize - Normalize according to the given loops.
26   Normalize,
27   /// Denormalize - Perform the inverse transform on the expression with the
28   /// given loop set.
29   Denormalize
30 };
31 
32 /// Hold the state used during post-inc expression transformation, including a
33 /// map of transformed expressions.
34 class PostIncTransform {
35   TransformKind Kind;
36   NormalizePredTy Pred;
37   ScalarEvolution &SE;
38 
39   DenseMap<const SCEV*, const SCEV*> Transformed;
40 
41 public:
42   PostIncTransform(TransformKind kind, NormalizePredTy Pred,
43                    ScalarEvolution &se)
44       : Kind(kind), Pred(Pred), SE(se) {}
45 
46   const SCEV *TransformSubExpr(const SCEV *S);
47 
48 protected:
49   const SCEV *TransformImpl(const SCEV *S);
50 };
51 
52 } // namespace
53 
54 /// Implement post-inc transformation for all valid expression types.
55 const SCEV *PostIncTransform::TransformImpl(const SCEV *S) {
56   if (const SCEVCastExpr *X = dyn_cast<SCEVCastExpr>(S)) {
57     const SCEV *O = X->getOperand();
58     const SCEV *N = TransformSubExpr(O);
59     if (O != N)
60       switch (S->getSCEVType()) {
61       case scZeroExtend: return SE.getZeroExtendExpr(N, S->getType());
62       case scSignExtend: return SE.getSignExtendExpr(N, S->getType());
63       case scTruncate: return SE.getTruncateExpr(N, S->getType());
64       default: llvm_unreachable("Unexpected SCEVCastExpr kind!");
65       }
66     return S;
67   }
68 
69   if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(S)) {
70     // An addrec. This is the interesting part.
71     SmallVector<const SCEV *, 8> Operands;
72 
73     transform(AR->operands(), std::back_inserter(Operands),
74               [&](const SCEV *Op) { return TransformSubExpr(Op); });
75 
76     // Conservatively use AnyWrap until/unless we need FlagNW.
77     const SCEV *Result =
78         SE.getAddRecExpr(Operands, AR->getLoop(), SCEV::FlagAnyWrap);
79     switch (Kind) {
80     case Normalize:
81       // We want to normalize step expression, because otherwise we might not be
82       // able to denormalize to the original expression.
83       //
84       // Here is an example what will happen if we don't normalize step:
85       //  ORIGINAL ISE:
86       //    {(100 /u {1,+,1}<%bb16>),+,(100 /u {1,+,1}<%bb16>)}<%bb25>
87       //  NORMALIZED ISE:
88       //    {((-1 * (100 /u {1,+,1}<%bb16>)) + (100 /u {0,+,1}<%bb16>)),+,
89       //     (100 /u {0,+,1}<%bb16>)}<%bb25>
90       //  DENORMALIZED BACK ISE:
91       //    {((2 * (100 /u {1,+,1}<%bb16>)) + (-1 * (100 /u {2,+,1}<%bb16>))),+,
92       //     (100 /u {1,+,1}<%bb16>)}<%bb25>
93       //  Note that the initial value changes after normalization +
94       //  denormalization, which isn't correct.
95       if (Pred(AR)) {
96         const SCEV *TransformedStep =
97             TransformSubExpr(AR->getStepRecurrence(SE));
98         Result = SE.getMinusSCEV(Result, TransformedStep);
99       }
100 #if 0
101       // See the comment on the assert above.
102       assert(S == TransformSubExpr(Result, User, OperandValToReplace) &&
103              "SCEV normalization is not invertible!");
104 #endif
105       break;
106     case Denormalize:
107       // Here we want to normalize step expressions for the same reasons, as
108       // stated above.
109       if (Pred(AR)) {
110         const SCEV *TransformedStep =
111             TransformSubExpr(AR->getStepRecurrence(SE));
112         Result = SE.getAddExpr(Result, TransformedStep);
113       }
114       break;
115     }
116     return Result;
117   }
118 
119   if (const SCEVNAryExpr *X = dyn_cast<SCEVNAryExpr>(S)) {
120     SmallVector<const SCEV *, 8> Operands;
121     bool Changed = false;
122     // Transform each operand.
123     for (auto *O : X->operands()) {
124       const SCEV *N = TransformSubExpr(O);
125       Changed |= N != O;
126       Operands.push_back(N);
127     }
128     // If any operand actually changed, return a transformed result.
129     if (Changed)
130       switch (S->getSCEVType()) {
131       case scAddExpr: return SE.getAddExpr(Operands);
132       case scMulExpr: return SE.getMulExpr(Operands);
133       case scSMaxExpr: return SE.getSMaxExpr(Operands);
134       case scUMaxExpr: return SE.getUMaxExpr(Operands);
135       default: llvm_unreachable("Unexpected SCEVNAryExpr kind!");
136       }
137     return S;
138   }
139 
140   if (const SCEVUDivExpr *X = dyn_cast<SCEVUDivExpr>(S)) {
141     const SCEV *LO = X->getLHS();
142     const SCEV *RO = X->getRHS();
143     const SCEV *LN = TransformSubExpr(LO);
144     const SCEV *RN = TransformSubExpr(RO);
145     if (LO != LN || RO != RN)
146       return SE.getUDivExpr(LN, RN);
147     return S;
148   }
149 
150   llvm_unreachable("Unexpected SCEV kind!");
151 }
152 
153 /// Manage recursive transformation across an expression DAG. Revisiting
154 /// expressions would lead to exponential recursion.
155 const SCEV *PostIncTransform::TransformSubExpr(const SCEV *S) {
156   if (isa<SCEVConstant>(S) || isa<SCEVUnknown>(S))
157     return S;
158 
159   const SCEV *Result = Transformed.lookup(S);
160   if (Result)
161     return Result;
162 
163   Result = TransformImpl(S);
164   Transformed[S] = Result;
165   return Result;
166 }
167 
168 /// Top level driver for transforming an expression DAG into its requested
169 /// post-inc form (either "Normalized" or "Denormalized").
170 static const SCEV *TransformForPostIncUse(TransformKind Kind, const SCEV *S,
171                                           NormalizePredTy Pred,
172                                           ScalarEvolution &SE) {
173   PostIncTransform Transform(Kind, Pred, SE);
174   return Transform.TransformSubExpr(S);
175 }
176 
177 const SCEV *llvm::normalizeForPostIncUse(const SCEV *S,
178                                          const PostIncLoopSet &Loops,
179                                          ScalarEvolution &SE) {
180   auto Pred = [&](const SCEVAddRecExpr *AR) {
181     return Loops.count(AR->getLoop());
182   };
183   return TransformForPostIncUse(Normalize, S, Pred, SE);
184 }
185 
186 const SCEV *llvm::normalizeForPostIncUseIf(const SCEV *S, NormalizePredTy Pred,
187                                            ScalarEvolution &SE) {
188   return TransformForPostIncUse(Normalize, S, Pred, SE);
189 }
190 
191 const SCEV *llvm::denormalizeForPostIncUse(const SCEV *S,
192                                            const PostIncLoopSet &Loops,
193                                            ScalarEvolution &SE) {
194   auto Pred = [&](const SCEVAddRecExpr *AR) {
195     return Loops.count(AR->getLoop());
196   };
197   return TransformForPostIncUse(Denormalize, S, Pred, SE);
198 }
199