xref: /llvm-project/polly/lib/Support/SCEVValidator.cpp (revision eeb776a41f24668898fe32e8f59ba2f793f14711)
1 
2 #include "polly/Support/SCEVValidator.h"
3 
4 #define DEBUG_TYPE "polly-scev-validator"
5 #include "llvm/Support/Debug.h"
6 #include "llvm/Analysis/ScalarEvolution.h"
7 #include "llvm/Analysis/ScalarEvolutionExpressions.h"
8 #include "llvm/Analysis/RegionInfo.h"
9 
10 #include <vector>
11 
12 using namespace llvm;
13 
namespace SCEVType {
/// @brief The type of a SCEV.
///
/// Every SCEV inspected by the validator is assigned one of these types. The
/// numeric order of the enumerators is significant: a SCEV of type X only has
/// subexpressions whose type is less than or equal to X, and merging two
/// results keeps the larger of the two types.
enum TYPE {
  /// An integer value.
  INT = 0,
  /// An expression that is constant during the execution of the SCoP, but
  /// that may depend on parameters unknown at compile time.
  PARAM = 1,
  /// An expression that may change during the execution of the SCoP.
  IV = 2,
  /// An invalid expression.
  INVALID = 3
};
} // end namespace SCEVType
36 
37 /// @brief The result the validator returns for a SCEV expression.
38 class ValidatorResult {
39   /// @brief The type of the expression
40   SCEVType::TYPE Type;
41 
42   /// @brief The set of Parameters in the expression.
43   std::vector<const SCEV*> Parameters;
44 
45 public:
46   /// @brief The copy constructor
47   ValidatorResult(const ValidatorResult &Source) {
48     Type = Source.Type;
49     Parameters = Source.Parameters;
50   };
51 
52   /// @brief Construct a result with a certain type and no parameters.
53   ValidatorResult(SCEVType::TYPE Type) : Type(Type) {
54     assert(Type != SCEVType::PARAM && "Did you forget to pass the parameter");
55   };
56 
57   /// @brief Construct a result with a certain type and a single parameter.
58   ValidatorResult(SCEVType::TYPE Type, const SCEV *Expr) : Type(Type) {
59     Parameters.push_back(Expr);
60   };
61 
62   /// @brief Get the type of the ValidatorResult.
63   SCEVType::TYPE getType() {
64     return Type;
65   }
66 
67   /// @brief Is the analyzed SCEV constant during the execution of the SCoP.
68   bool isConstant() {
69     return Type == SCEVType::INT || Type == SCEVType::PARAM;
70   }
71 
72   /// @brief Is the analyzed SCEV valid.
73   bool isValid() {
74     return Type != SCEVType::INVALID;
75   }
76 
77   /// @brief Is the analyzed SCEV of Type IV.
78   bool isIV() {
79     return Type == SCEVType::IV;
80   }
81 
82   /// @brief Is the analyzed SCEV of Type INT.
83   bool isINT() {
84     return Type == SCEVType::INT;
85   }
86 
87   /// @brief Is the analyzed SCEV of Type PARAM.
88   bool isPARAM() {
89     return Type == SCEVType::PARAM;
90   }
91 
92   /// @brief Get the parameters of this validator result.
93   std::vector<const SCEV*> getParameters() {
94     return Parameters;
95   }
96 
97   /// @brief Add the parameters of Source to this result.
98   void addParamsFrom(class ValidatorResult &Source) {
99     Parameters.insert(Parameters.end(),
100                       Source.Parameters.begin(),
101                       Source.Parameters.end());
102   }
103 
104   /// @brief Merge a result.
105   ///
106   /// This means to merge the parameters and to set the Type to the most
107   /// specific Type that matches both.
108   void merge(class ValidatorResult &ToMerge) {
109     Type = std::max(Type, ToMerge.Type);
110     addParamsFrom(ToMerge);
111   }
112 
113   void print(raw_ostream &OS) {
114     switch (Type) {
115       case SCEVType::INT:
116         OS << "SCEVType::INT";
117       break;
118       case SCEVType::PARAM:
119         OS << "SCEVType::PARAM";
120       break;
121       case SCEVType::IV:
122         OS << "SCEVType::IV";
123       break;
124       case SCEVType::INVALID:
125         OS << "SCEVType::INVALID";
126       break;
127     }
128   }
129 };
130 
131 raw_ostream &operator<<(raw_ostream &OS, class ValidatorResult &VR) {
132   VR.print(OS);
133   return OS;
134 }
135 
136 /// Check if a SCEV is valid in a SCoP.
137 struct SCEVValidator
138   : public SCEVVisitor<SCEVValidator, class ValidatorResult> {
139 private:
140   const Region *R;
141   ScalarEvolution &SE;
142   const Value *BaseAddress;
143 
144 public:
145   SCEVValidator(const Region *R, ScalarEvolution &SE,
146                 const Value *BaseAddress) : R(R), SE(SE),
147     BaseAddress(BaseAddress) {};
148 
149   class ValidatorResult visitConstant(const SCEVConstant *Constant) {
150     return ValidatorResult(SCEVType::INT);
151   }
152 
153   class ValidatorResult visitTruncateExpr(const SCEVTruncateExpr *Expr) {
154     ValidatorResult Op = visit(Expr->getOperand());
155 
156     switch (Op.getType()) {
157       case SCEVType::INT:
158       case SCEVType::PARAM:
159        // We currently do not represent a truncate expression as an affine
160        // expression. If it is constant during Scop execution, we treat it as a
161        // parameter.
162         return ValidatorResult(SCEVType::PARAM, Expr);
163       case SCEVType::IV:
164         DEBUG(dbgs() << "INVALID: Truncation of SCEVType::IV expression");
165         return ValidatorResult(SCEVType::INVALID);
166       case SCEVType::INVALID:
167         return Op;
168     }
169 
170     llvm_unreachable("Unknown SCEVType");
171   }
172 
173   class ValidatorResult visitZeroExtendExpr(const SCEVZeroExtendExpr *Expr) {
174     ValidatorResult Op = visit(Expr->getOperand());
175 
176     switch (Op.getType()) {
177       case SCEVType::INT:
178       case SCEVType::PARAM:
179        // We currently do not represent a truncate expression as an affine
180        // expression. If it is constant during Scop execution, we treat it as a
181        // parameter.
182         return ValidatorResult(SCEVType::PARAM, Expr);
183       case SCEVType::IV:
184         DEBUG(dbgs() << "INVALID: ZeroExtend of SCEVType::IV expression");
185         return ValidatorResult(SCEVType::INVALID);
186       case SCEVType::INVALID:
187         return Op;
188     }
189 
190     llvm_unreachable("Unknown SCEVType");
191   }
192 
193   class ValidatorResult visitSignExtendExpr(const SCEVSignExtendExpr *Expr) {
194     // We currently allow only signed SCEV expressions. In the case of a
195     // signed value, a sign extend is a noop.
196     //
197     // TODO: Reconsider this when we add support for unsigned values.
198     return visit(Expr->getOperand());
199   }
200 
201   class ValidatorResult visitAddExpr(const SCEVAddExpr *Expr) {
202     ValidatorResult Return(SCEVType::INT);
203 
204     for (int i = 0, e = Expr->getNumOperands(); i < e; ++i) {
205       ValidatorResult Op = visit(Expr->getOperand(i));
206       Return.merge(Op);
207 
208       // Early exit.
209       if (!Return.isValid())
210         break;
211     }
212 
213     // TODO: Check for NSW and NUW.
214     return Return;
215   }
216 
217   class ValidatorResult visitMulExpr(const SCEVMulExpr *Expr) {
218     ValidatorResult Return(SCEVType::INT);
219 
220     for (int i = 0, e = Expr->getNumOperands(); i < e; ++i) {
221       ValidatorResult Op = visit(Expr->getOperand(i));
222 
223       if (Op.isINT())
224         continue;
225 
226       if ((Op.isIV() || Op.isPARAM()) && !Return.isINT() ) {
227         DEBUG(dbgs() << "INVALID: More than one non-int operand in MulExpr\n"
228                      << "\tExpr: " << *Expr << "\n"
229                      << "\tPrevious expression type: " << Return << "\n"
230                      << "\tNext operand (" << Op << "): "
231                      << *Expr->getOperand(i) << "\n");
232 
233         return ValidatorResult(SCEVType::INVALID);
234       }
235 
236       Return.merge(Op);
237     }
238 
239     // TODO: Check for NSW and NUW.
240     return Return;
241   }
242 
243   class ValidatorResult visitUDivExpr(const SCEVUDivExpr *Expr) {
244     ValidatorResult LHS = visit(Expr->getLHS());
245     ValidatorResult RHS = visit(Expr->getRHS());
246 
247     // We currently do not represent an unsigned division as an affine
248     // expression. If the division is constant during Scop execution we treat it
249     // as a parameter, otherwise we bail out.
250     if (LHS.isConstant() && RHS.isConstant())
251       return ValidatorResult(SCEVType::PARAM, Expr);
252 
253     DEBUG(dbgs() << "INVALID: unsigned division of non-constant expressions");
254     return ValidatorResult(SCEVType::INVALID);
255   }
256 
257   class ValidatorResult visitAddRecExpr(const SCEVAddRecExpr *Expr) {
258     if (!Expr->isAffine()) {
259       DEBUG(dbgs() << "INVALID: AddRec is not affine");
260       return ValidatorResult(SCEVType::INVALID);
261     }
262 
263     ValidatorResult Start = visit(Expr->getStart());
264     ValidatorResult Recurrence = visit(Expr->getStepRecurrence(SE));
265 
266     if (!Start.isValid())
267       return Start;
268 
269     if (!Recurrence.isValid())
270       return Recurrence;
271 
272     if (R->contains(Expr->getLoop())) {
273       if (Recurrence.isINT()) {
274         ValidatorResult Result(SCEVType::IV);
275         Result.addParamsFrom(Start);
276         return Result;
277       }
278 
279       DEBUG(dbgs() << "INVALID: AddRec within scop has non-int"
280                       "recurrence part");
281       return ValidatorResult(SCEVType::INVALID);
282     }
283 
284     assert (Start.isConstant() && Recurrence.isConstant()
285             && "Expected 'Start' and 'Recurrence' to be constant");
286     return ValidatorResult(SCEVType::PARAM, Expr);
287   }
288 
289   class ValidatorResult visitSMaxExpr(const SCEVSMaxExpr *Expr) {
290     ValidatorResult Return(SCEVType::INT, Expr);
291 
292     for (int i = 0, e = Expr->getNumOperands(); i < e; ++i) {
293       ValidatorResult Op = visit(Expr->getOperand(i));
294 
295       if (!Op.isValid())
296         return Op;
297 
298       Return.merge(Op);
299     }
300 
301     return Return;
302   }
303 
304   class ValidatorResult visitUMaxExpr(const SCEVUMaxExpr *Expr) {
305     // We do not support unsigned operations. If 'Expr' is constant during Scop
306     // execution we treat this as a parameter, otherwise we bail out.
307     for (int i = 0, e = Expr->getNumOperands(); i < e; ++i) {
308       ValidatorResult Op = visit(Expr->getOperand(i));
309 
310       if (!Op.isConstant()) {
311         DEBUG(dbgs() << "INVALID: UMaxExpr has a non-constant operand");
312         return ValidatorResult(SCEVType::INVALID);
313       }
314     }
315 
316     return ValidatorResult(SCEVType::PARAM, Expr);
317   }
318 
319   ValidatorResult visitUnknown(const SCEVUnknown *Expr) {
320     Value *V = Expr->getValue();
321 
322     // We currently only support integer types. It may be useful to support
323     // pointer types, e.g. to support code like:
324     //
325     //   if (A)
326     //     A[i] = 1;
327     //
328     // See test/CodeGen/20120316-InvalidCast.ll
329     if (!Expr->getType()->isIntegerTy()) {
330       DEBUG(dbgs() << "INVALID: UnknownExpr is not an integer type");
331       return ValidatorResult(SCEVType::INVALID);
332     }
333 
334     if (isa<UndefValue>(V)) {
335       DEBUG(dbgs() << "INVALID: UnknownExpr references an undef value");
336       return ValidatorResult(SCEVType::INVALID);
337     }
338 
339     if (Instruction *I = dyn_cast<Instruction>(Expr->getValue()))
340       if (R->contains(I)) {
341         DEBUG(dbgs() << "INVALID: UnknownExpr references an instruction "
342                         "within the region\n");
343         return ValidatorResult(SCEVType::INVALID);
344       }
345 
346     if (BaseAddress == V) {
347       DEBUG(dbgs() << "INVALID: UnknownExpr references BaseAddress\n");
348       return ValidatorResult(SCEVType::INVALID);
349     }
350 
351     return ValidatorResult(SCEVType::PARAM, Expr);
352   }
353 };
354 
355 namespace polly {
356   bool isAffineExpr(const Region *R, const SCEV *Expr, ScalarEvolution &SE,
357                     const Value *BaseAddress) {
358     if (isa<SCEVCouldNotCompute>(Expr))
359       return false;
360 
361     SCEVValidator Validator(R, SE, BaseAddress);
362     DEBUG(
363       dbgs() << "\n";
364       dbgs() << "Expr: " << *Expr << "\n";
365       dbgs() << "Region: " << R->getNameStr() << "\n";
366       dbgs() << " -> ");
367 
368     ValidatorResult Result = Validator.visit(Expr);
369 
370     DEBUG(
371       if (Result.isValid())
372         dbgs() << "VALID\n";
373       dbgs() << "\n";
374     );
375 
376     return Result.isValid();
377   }
378 
379   std::vector<const SCEV*> getParamsInAffineExpr(const Region *R,
380                                                  const SCEV *Expr,
381                                                  ScalarEvolution &SE,
382                                                  const Value *BaseAddress) {
383     if (isa<SCEVCouldNotCompute>(Expr))
384       return std::vector<const SCEV*>();
385 
386     SCEVValidator Validator(R, SE, BaseAddress);
387     ValidatorResult Result = Validator.visit(Expr);
388 
389     return Result.getParameters();
390   }
391 }
392 
393 
394