xref: /llvm-project/polly/lib/Support/SCEVValidator.cpp (revision 3ec2abc5fb2e42241ec454bc24c80b9e34767b05)
1 
2 #include "polly/Support/SCEVValidator.h"
3 
4 #include "llvm/Analysis/ScalarEvolution.h"
5 #include "llvm/Analysis/ScalarEvolutionExpressions.h"
6 #include "llvm/Analysis/RegionInfo.h"
7 
8 #include <vector>
9 
10 using namespace llvm;
11 
namespace SCEVType {
  /// @brief Classification of a SCEV expression.
  ///
  /// Validation assigns one of INT, PARAM, IV or INVALID to every SCEV. The
  /// numeric order of the enumerators is significant: a SCEV classified as X
  /// may only have subexpressions whose classification is smaller than or
  /// equal to X.
  enum TYPE {
    /// A plain integer value.
    INT = 0,

    /// A value that stays constant while the SCoP executes, but which may
    /// depend on parameters that are unknown at compile time.
    PARAM = 1,

    /// A value that may change while the SCoP executes.
    IV = 2,

    /// An expression that is not valid.
    INVALID = 3
  };
}
34 
35 /// @brief The result the validator returns for a SCEV expression.
36 class ValidatorResult {
37   /// @brief The type of the expression
38   SCEVType::TYPE Type;
39 
40   /// @brief The set of Parameters in the expression.
41   std::vector<const SCEV*> Parameters;
42 
43 public:
44 
45   /// @brief Create an invalid result.
46   ValidatorResult() : Type(SCEVType::INVALID) {};
47 
48   /// @brief The copy constructor
49   ValidatorResult(const ValidatorResult &Source) {
50     Type = Source.Type;
51     Parameters = Source.Parameters;
52   };
53 
54   /// @brief Construct a result with a certain type and no parameters.
55   ValidatorResult(SCEVType::TYPE Type) : Type(Type) {
56     assert(Type != SCEVType::PARAM && "Did you forget to pass the parameter");
57   };
58 
59   /// @brief Construct a result with a certain type and a single parameter.
60   ValidatorResult(SCEVType::TYPE Type, const SCEV *Expr) : Type(Type) {
61     Parameters.push_back(Expr);
62   };
63 
64   /// @brief Is the analyzed SCEV constant during the execution of the SCoP.
65   bool isConstant() {
66     return Type == SCEVType::INT || Type == SCEVType::PARAM;
67   }
68 
69   /// @brief Is the analyzed SCEV valid.
70   bool isValid() {
71     return Type != SCEVType::INVALID;
72   }
73 
74   /// @brief Is the analyzed SCEV of Type IV.
75   bool isIV() {
76     return Type == SCEVType::IV;
77   }
78 
79   /// @brief Is the analyzed SCEV of Type INT.
80   bool isINT() {
81     return Type == SCEVType::INT;
82   }
83 
84   /// @brief Get the parameters of this validator result.
85   std::vector<const SCEV*> getParameters() {
86     return Parameters;
87   }
88 
89   /// @brief Add the parameters of Source to this result.
90   void addParamsFrom(class ValidatorResult &Source) {
91     Parameters.insert(Parameters.end(),
92                       Source.Parameters.begin(),
93                       Source.Parameters.end());
94   }
95 
96   /// @brief Merge a result.
97   ///
98   /// This means to merge the parameters and to set the Type to the most
99   /// specific Type that matches both.
100   void merge(class ValidatorResult &ToMerge) {
101     Type = std::max(Type, ToMerge.Type);
102     addParamsFrom(ToMerge);
103   }
104 
105   void print(raw_ostream &OS) {
106     switch (Type) {
107       case SCEVType::INT:
108         OS << "SCEVType::INT\n";
109       break;
110       case SCEVType::PARAM:
111         OS << "SCEVType::PARAM\n";
112       break;
113       case SCEVType::IV:
114         OS << "SCEVType::IV\n";
115       break;
116       case SCEVType::INVALID:
117         OS << "SCEVType::INVALID\n";
118       break;
119     }
120   }
121 };
122 
123 raw_ostream &operator<<(raw_ostream &OS, class ValidatorResult &VR) {
124   VR.print(OS);
125   return OS;
126 }
127 
128 /// Check if a SCEV is valid in a SCoP.
129 struct SCEVValidator
130   : public SCEVVisitor<SCEVValidator, class ValidatorResult> {
131 private:
132   const Region *R;
133   ScalarEvolution &SE;
134   const Value *BaseAddress;
135 
136 public:
137   SCEVValidator(const Region *R, ScalarEvolution &SE,
138                 const Value *BaseAddress) : R(R), SE(SE),
139     BaseAddress(BaseAddress) {};
140 
141   class ValidatorResult visitConstant(const SCEVConstant *Constant) {
142     return ValidatorResult(SCEVType::INT);
143   }
144 
145   class ValidatorResult visitTruncateExpr(const SCEVTruncateExpr *Expr) {
146     ValidatorResult Op = visit(Expr->getOperand());
147 
148     // We currently do not represent a truncate expression as an affine
149     // expression. If it is constant during Scop execution, we treat it as a
150     // parameter, otherwise we bail out.
151     if (Op.isConstant())
152       return ValidatorResult(SCEVType::PARAM, Expr);
153 
154     return ValidatorResult(SCEVType::INVALID);
155   }
156 
157   class ValidatorResult visitZeroExtendExpr(const SCEVZeroExtendExpr *Expr) {
158     ValidatorResult Op = visit(Expr->getOperand());
159 
160     // We currently do not represent a zero extend expression as an affine
161     // expression. If it is constant during Scop execution, we treat it as a
162     // parameter, otherwise we bail out.
163     if (Op.isConstant())
164       return ValidatorResult(SCEVType::PARAM, Expr);
165 
166     return ValidatorResult(SCEVType::INVALID);
167   }
168 
169   class ValidatorResult visitSignExtendExpr(const SCEVSignExtendExpr *Expr) {
170     // We currently allow only signed SCEV expressions. In the case of a
171     // signed value, a sign extend is a noop.
172     //
173     // TODO: Reconsider this when we add support for unsigned values.
174     return visit(Expr->getOperand());
175   }
176 
177   class ValidatorResult visitAddExpr(const SCEVAddExpr *Expr) {
178     ValidatorResult Return(SCEVType::INT);
179 
180     for (int i = 0, e = Expr->getNumOperands(); i < e; ++i) {
181       ValidatorResult Op = visit(Expr->getOperand(i));
182 
183       if (!Op.isValid())
184         return ValidatorResult(SCEVType::INVALID);
185 
186       Return.merge(Op);
187     }
188 
189     // TODO: Check for NSW and NUW.
190     return Return;
191   }
192 
193   class ValidatorResult visitMulExpr(const SCEVMulExpr *Expr) {
194     ValidatorResult Return(SCEVType::INT);
195 
196     for (int i = 0, e = Expr->getNumOperands(); i < e; ++i) {
197       ValidatorResult Op = visit(Expr->getOperand(i));
198 
199       if (Op.isINT())
200         continue;
201 
202       if (!Op.isValid() || !Return.isINT())
203         return ValidatorResult(SCEVType::INVALID);
204 
205       Return.merge(Op);
206     }
207 
208     // TODO: Check for NSW and NUW.
209     return Return;
210   }
211 
212   class ValidatorResult visitUDivExpr(const SCEVUDivExpr *Expr) {
213     ValidatorResult LHS = visit(Expr->getLHS());
214     ValidatorResult RHS = visit(Expr->getRHS());
215 
216     // We currently do not represent an unsigned devision as an affine
217     // expression. If the division is constant during Scop execution we treat it
218     // as a parameter, otherwise we bail out.
219     if (LHS.isConstant() && RHS.isConstant())
220       return ValidatorResult(SCEVType::PARAM, Expr);
221 
222     return ValidatorResult(SCEVType::INVALID);
223   }
224 
225   class ValidatorResult visitAddRecExpr(const SCEVAddRecExpr *Expr) {
226     if (!Expr->isAffine())
227       return ValidatorResult(SCEVType::INVALID);
228 
229     ValidatorResult Start = visit(Expr->getStart());
230     ValidatorResult Recurrence = visit(Expr->getStepRecurrence(SE));
231 
232     if (!Start.isValid() || !Recurrence.isConstant())
233       return ValidatorResult(SCEVType::INVALID);
234 
235     if (R->contains(Expr->getLoop())) {
236       if (Recurrence.isINT()) {
237         ValidatorResult Result(SCEVType::IV);
238         Result.addParamsFrom(Start);
239         return Result;
240       }
241 
242       return ValidatorResult(SCEVType::INVALID);
243     }
244 
245     if (Start.isConstant())
246       return ValidatorResult(SCEVType::PARAM, Expr);
247 
248     return ValidatorResult(SCEVType::INVALID);
249   }
250 
251   class ValidatorResult visitSMaxExpr(const SCEVSMaxExpr *Expr) {
252     ValidatorResult Return(SCEVType::INT, Expr);
253 
254     for (int i = 0, e = Expr->getNumOperands(); i < e; ++i) {
255       ValidatorResult Op = visit(Expr->getOperand(i));
256 
257       if (!Op.isValid())
258         return ValidatorResult(SCEVType::INVALID);
259 
260       Return.merge(Op);
261     }
262 
263     return Return;
264   }
265 
266   class ValidatorResult visitUMaxExpr(const SCEVUMaxExpr *Expr) {
267     // We do not support unsigned operations. If 'Expr' is constant during Scop
268     // execution we treat this as a parameter, otherwise we bail out.
269     for (int i = 0, e = Expr->getNumOperands(); i < e; ++i) {
270       ValidatorResult Op = visit(Expr->getOperand(i));
271 
272       if (!Op.isConstant())
273         return ValidatorResult(SCEVType::INVALID);
274     }
275 
276     return ValidatorResult(SCEVType::PARAM, Expr);
277   }
278 
279   ValidatorResult visitUnknown(const SCEVUnknown *Expr) {
280     Value *V = Expr->getValue();
281 
282     // We currently only support integer types. It may be useful to support
283     // pointer types, e.g. to support code like:
284     //
285     //   if (A)
286     //     A[i] = 1;
287     //
288     // See test/CodeGen/20120316-InvalidCast.ll
289     if (!Expr->getType()->isIntegerTy())
290       return ValidatorResult(SCEVType::INVALID);
291 
292     if (isa<UndefValue>(V))
293       return ValidatorResult(SCEVType::INVALID);
294 
295     if (Instruction *I = dyn_cast<Instruction>(Expr->getValue()))
296       if (R->contains(I))
297         return ValidatorResult(SCEVType::INVALID);
298 
299     if (BaseAddress == V)
300       return ValidatorResult(SCEVType::INVALID);
301 
302     return ValidatorResult(SCEVType::PARAM, Expr);
303   }
304 };
305 
306 namespace polly {
307   bool isAffineExpr(const Region *R, const SCEV *Expr, ScalarEvolution &SE,
308                     const Value *BaseAddress) {
309     if (isa<SCEVCouldNotCompute>(Expr))
310       return false;
311 
312     SCEVValidator Validator(R, SE, BaseAddress);
313     ValidatorResult Result = Validator.visit(Expr);
314 
315     return Result.isValid();
316   }
317 
318   std::vector<const SCEV*> getParamsInAffineExpr(const Region *R,
319                                                  const SCEV *Expr,
320                                                  ScalarEvolution &SE,
321                                                  const Value *BaseAddress) {
322     if (isa<SCEVCouldNotCompute>(Expr))
323       return std::vector<const SCEV*>();
324 
325     SCEVValidator Validator(R, SE, BaseAddress);
326     ValidatorResult Result = Validator.visit(Expr);
327 
328     return Result.getParameters();
329   }
330 }
331 
332 
333