xref: /llvm-project/llvm/lib/Analysis/ScalarEvolutionNormalization.cpp (revision c5a87a194916aca1bc7f95d36d18d8ccc5efea8a)
1 //===- ScalarEvolutionNormalization.cpp - See below -----------------------===//
2 //
3 //                     The LLVM Compiler Infrastructure
4 //
5 // This file is distributed under the University of Illinois Open Source
6 // License. See LICENSE.TXT for details.
7 //
8 //===----------------------------------------------------------------------===//
9 //
10 // This file implements utilities for working with "normalized" expressions.
11 // See the comments at the top of ScalarEvolutionNormalization.h for details.
12 //
13 //===----------------------------------------------------------------------===//
14 
15 #include "llvm/Analysis/LoopInfo.h"
16 #include "llvm/Analysis/ScalarEvolutionExpressions.h"
17 #include "llvm/Analysis/ScalarEvolutionNormalization.h"
18 using namespace llvm;
19 
20 namespace {
21 
22 /// TransformKind - Different types of transformations that
23 /// TransformForPostIncUse can do.
24 enum TransformKind {
25   /// Normalize - Normalize according to the given loops.
26   Normalize,
27   /// Denormalize - Perform the inverse transform on the expression with the
28   /// given loop set.
29   Denormalize
30 };
31 
32 /// Hold the state used during post-inc expression transformation, including a
33 /// map of transformed expressions.
34 class PostIncTransform {
35   TransformKind Kind;
36   NormalizePredTy Pred;
37   ScalarEvolution &SE;
38 
39   DenseMap<const SCEV*, const SCEV*> Transformed;
40 
41 public:
42   PostIncTransform(TransformKind kind, NormalizePredTy Pred,
43                    ScalarEvolution &se)
44       : Kind(kind), Pred(Pred), SE(se) {}
45 
46   const SCEV *TransformSubExpr(const SCEV *S);
47 
48 protected:
49   const SCEV *TransformImpl(const SCEV *S);
50 };
51 
52 } // namespace
53 
54 /// Implement post-inc transformation for all valid expression types.
55 const SCEV *PostIncTransform::TransformImpl(const SCEV *S) {
56   if (const SCEVCastExpr *X = dyn_cast<SCEVCastExpr>(S)) {
57     const SCEV *O = X->getOperand();
58     const SCEV *N = TransformSubExpr(O);
59     if (O != N)
60       switch (S->getSCEVType()) {
61       case scZeroExtend: return SE.getZeroExtendExpr(N, S->getType());
62       case scSignExtend: return SE.getSignExtendExpr(N, S->getType());
63       case scTruncate: return SE.getTruncateExpr(N, S->getType());
64       default: llvm_unreachable("Unexpected SCEVCastExpr kind!");
65       }
66     return S;
67   }
68 
69   if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(S)) {
70     // An addrec. This is the interesting part.
71     SmallVector<const SCEV *, 8> Operands;
72 
73     transform(AR->operands(), std::back_inserter(Operands),
74               [&](const SCEV *Op) { return TransformSubExpr(Op); });
75 
76     // Conservatively use AnyWrap until/unless we need FlagNW.
77     const SCEV *Result =
78         SE.getAddRecExpr(Operands, AR->getLoop(), SCEV::FlagAnyWrap);
79     switch (Kind) {
80     case Normalize:
81       // We want to normalize step expression, because otherwise we might not be
82       // able to denormalize to the original expression.
83       //
84       // Here is an example what will happen if we don't normalize step:
85       //  ORIGINAL ISE:
86       //    {(100 /u {1,+,1}<%bb16>),+,(100 /u {1,+,1}<%bb16>)}<%bb25>
87       //  NORMALIZED ISE:
88       //    {((-1 * (100 /u {1,+,1}<%bb16>)) + (100 /u {0,+,1}<%bb16>)),+,
89       //     (100 /u {0,+,1}<%bb16>)}<%bb25>
90       //  DENORMALIZED BACK ISE:
91       //    {((2 * (100 /u {1,+,1}<%bb16>)) + (-1 * (100 /u {2,+,1}<%bb16>))),+,
92       //     (100 /u {1,+,1}<%bb16>)}<%bb25>
93       //  Note that the initial value changes after normalization +
94       //  denormalization, which isn't correct.
95       if (Pred(AR)) {
96         const SCEV *TransformedStep =
97             TransformSubExpr(AR->getStepRecurrence(SE));
98         Result = SE.getMinusSCEV(Result, TransformedStep);
99       }
100 #if 0
101       // See the comment on the assert above.
102       assert(S == TransformSubExpr(Result, User, OperandValToReplace) &&
103              "SCEV normalization is not invertible!");
104 #endif
105       break;
106     case Denormalize:
107       // Here we want to normalize step expressions for the same reasons, as
108       // stated above.
109       if (Pred(AR)) {
110         const SCEV *TransformedStep =
111             TransformSubExpr(AR->getStepRecurrence(SE));
112         Result = SE.getAddExpr(Result, TransformedStep);
113       }
114       break;
115     }
116     return Result;
117   }
118 
119   if (const SCEVNAryExpr *X = dyn_cast<SCEVNAryExpr>(S)) {
120     SmallVector<const SCEV *, 8> Operands;
121     bool Changed = false;
122     // Transform each operand.
123     for (SCEVNAryExpr::op_iterator I = X->op_begin(), E = X->op_end();
124          I != E; ++I) {
125       const SCEV *O = *I;
126       const SCEV *N = TransformSubExpr(O);
127       Changed |= N != O;
128       Operands.push_back(N);
129     }
130     // If any operand actually changed, return a transformed result.
131     if (Changed)
132       switch (S->getSCEVType()) {
133       case scAddExpr: return SE.getAddExpr(Operands);
134       case scMulExpr: return SE.getMulExpr(Operands);
135       case scSMaxExpr: return SE.getSMaxExpr(Operands);
136       case scUMaxExpr: return SE.getUMaxExpr(Operands);
137       default: llvm_unreachable("Unexpected SCEVNAryExpr kind!");
138       }
139     return S;
140   }
141 
142   if (const SCEVUDivExpr *X = dyn_cast<SCEVUDivExpr>(S)) {
143     const SCEV *LO = X->getLHS();
144     const SCEV *RO = X->getRHS();
145     const SCEV *LN = TransformSubExpr(LO);
146     const SCEV *RN = TransformSubExpr(RO);
147     if (LO != LN || RO != RN)
148       return SE.getUDivExpr(LN, RN);
149     return S;
150   }
151 
152   llvm_unreachable("Unexpected SCEV kind!");
153 }
154 
155 /// Manage recursive transformation across an expression DAG. Revisiting
156 /// expressions would lead to exponential recursion.
157 const SCEV *PostIncTransform::TransformSubExpr(const SCEV *S) {
158   if (isa<SCEVConstant>(S) || isa<SCEVUnknown>(S))
159     return S;
160 
161   const SCEV *Result = Transformed.lookup(S);
162   if (Result)
163     return Result;
164 
165   Result = TransformImpl(S);
166   Transformed[S] = Result;
167   return Result;
168 }
169 
170 /// Top level driver for transforming an expression DAG into its requested
171 /// post-inc form (either "Normalized" or "Denormalized").
172 static const SCEV *TransformForPostIncUse(TransformKind Kind, const SCEV *S,
173                                           NormalizePredTy Pred,
174                                           ScalarEvolution &SE) {
175   PostIncTransform Transform(Kind, Pred, SE);
176   return Transform.TransformSubExpr(S);
177 }
178 
179 const SCEV *llvm::normalizeForPostIncUse(const SCEV *S,
180                                          const PostIncLoopSet &Loops,
181                                          ScalarEvolution &SE) {
182   auto Pred = [&](const SCEVAddRecExpr *AR) {
183     return Loops.count(AR->getLoop());
184   };
185   return TransformForPostIncUse(Normalize, S, Pred, SE);
186 }
187 
188 const SCEV *llvm::normalizeForPostIncUseIf(const SCEV *S, NormalizePredTy Pred,
189                                            ScalarEvolution &SE) {
190   return TransformForPostIncUse(Normalize, S, Pred, SE);
191 }
192 
193 const SCEV *llvm::denormalizeForPostIncUse(const SCEV *S,
194                                            const PostIncLoopSet &Loops,
195                                            ScalarEvolution &SE) {
196   auto Pred = [&](const SCEVAddRecExpr *AR) {
197     return Loops.count(AR->getLoop());
198   };
199   return TransformForPostIncUse(Denormalize, S, Pred, SE);
200 }
201