xref: /llvm-project/polly/lib/Support/SCEVValidator.cpp (revision 55bc4c076706b6410a6bae84210ab32b95f272c9)
1 
2 #include "polly/Support/SCEVValidator.h"
3 #include "polly/ScopInfo.h"
4 #include "llvm/Analysis/ScalarEvolution.h"
5 #include "llvm/Analysis/ScalarEvolutionExpressions.h"
6 #include "llvm/Analysis/RegionInfo.h"
7 #include "llvm/Support/Debug.h"
8 
9 #include <vector>
10 
11 using namespace llvm;
12 
13 #define DEBUG_TYPE "polly-scev-validator"
14 
namespace SCEVType {
/// @brief Classification the validator assigns to every SCEV it visits.
///
/// The enumerators are listed from INT to INVALID and that order is
/// significant: a SCEV of type X only has subexpressions whose type is less
/// than or equal to X, and ValidatorResult::merge combines two results by
/// taking the larger type (std::max).
enum TYPE {
  /// An integer value.
  INT = 0,

  /// Constant while the SCoP executes, but possibly depending on parameters
  /// whose values are unknown at compile time.
  PARAM = 1,

  /// May change while the SCoP executes.
  IV = 2,

  /// An expression the validator cannot represent.
  INVALID = 3
};
} // namespace SCEVType
37 
38 /// @brief The result the validator returns for a SCEV expression.
39 class ValidatorResult {
40   /// @brief The type of the expression
41   SCEVType::TYPE Type;
42 
43   /// @brief The set of Parameters in the expression.
44   std::vector<const SCEV *> Parameters;
45 
46 public:
47   /// @brief The copy constructor
48   ValidatorResult(const ValidatorResult &Source) {
49     Type = Source.Type;
50     Parameters = Source.Parameters;
51   }
52 
53   /// @brief Construct a result with a certain type and no parameters.
54   ValidatorResult(SCEVType::TYPE Type) : Type(Type) {
55     assert(Type != SCEVType::PARAM && "Did you forget to pass the parameter");
56   }
57 
58   /// @brief Construct a result with a certain type and a single parameter.
59   ValidatorResult(SCEVType::TYPE Type, const SCEV *Expr) : Type(Type) {
60     Parameters.push_back(Expr);
61   }
62 
63   /// @brief Get the type of the ValidatorResult.
64   SCEVType::TYPE getType() { return Type; }
65 
66   /// @brief Is the analyzed SCEV constant during the execution of the SCoP.
67   bool isConstant() { return Type == SCEVType::INT || Type == SCEVType::PARAM; }
68 
69   /// @brief Is the analyzed SCEV valid.
70   bool isValid() { return Type != SCEVType::INVALID; }
71 
72   /// @brief Is the analyzed SCEV of Type IV.
73   bool isIV() { return Type == SCEVType::IV; }
74 
75   /// @brief Is the analyzed SCEV of Type INT.
76   bool isINT() { return Type == SCEVType::INT; }
77 
78   /// @brief Is the analyzed SCEV of Type PARAM.
79   bool isPARAM() { return Type == SCEVType::PARAM; }
80 
81   /// @brief Get the parameters of this validator result.
82   std::vector<const SCEV *> getParameters() { return Parameters; }
83 
84   /// @brief Add the parameters of Source to this result.
85   void addParamsFrom(class ValidatorResult &Source) {
86     Parameters.insert(Parameters.end(), Source.Parameters.begin(),
87                       Source.Parameters.end());
88   }
89 
90   /// @brief Merge a result.
91   ///
92   /// This means to merge the parameters and to set the Type to the most
93   /// specific Type that matches both.
94   void merge(class ValidatorResult &ToMerge) {
95     Type = std::max(Type, ToMerge.Type);
96     addParamsFrom(ToMerge);
97   }
98 
99   void print(raw_ostream &OS) {
100     switch (Type) {
101     case SCEVType::INT:
102       OS << "SCEVType::INT";
103       break;
104     case SCEVType::PARAM:
105       OS << "SCEVType::PARAM";
106       break;
107     case SCEVType::IV:
108       OS << "SCEVType::IV";
109       break;
110     case SCEVType::INVALID:
111       OS << "SCEVType::INVALID";
112       break;
113     }
114   }
115 };
116 
117 raw_ostream &operator<<(raw_ostream &OS, class ValidatorResult &VR) {
118   VR.print(OS);
119   return OS;
120 }
121 
122 /// Check if a SCEV is valid in a SCoP.
123 struct SCEVValidator
124     : public SCEVVisitor<SCEVValidator, class ValidatorResult> {
125 private:
126   const Region *R;
127   ScalarEvolution &SE;
128   const Value *BaseAddress;
129 
130 public:
131   SCEVValidator(const Region *R, ScalarEvolution &SE, const Value *BaseAddress)
132       : R(R), SE(SE), BaseAddress(BaseAddress) {}
133 
134   class ValidatorResult visitConstant(const SCEVConstant *Constant) {
135     return ValidatorResult(SCEVType::INT);
136   }
137 
138   class ValidatorResult visitTruncateExpr(const SCEVTruncateExpr *Expr) {
139     ValidatorResult Op = visit(Expr->getOperand());
140 
141     switch (Op.getType()) {
142     case SCEVType::INT:
143     case SCEVType::PARAM:
144       // We currently do not represent a truncate expression as an affine
145       // expression. If it is constant during Scop execution, we treat it as a
146       // parameter.
147       return ValidatorResult(SCEVType::PARAM, Expr);
148     case SCEVType::IV:
149       DEBUG(dbgs() << "INVALID: Truncation of SCEVType::IV expression");
150       return ValidatorResult(SCEVType::INVALID);
151     case SCEVType::INVALID:
152       return Op;
153     }
154 
155     llvm_unreachable("Unknown SCEVType");
156   }
157 
158   class ValidatorResult visitZeroExtendExpr(const SCEVZeroExtendExpr *Expr) {
159     ValidatorResult Op = visit(Expr->getOperand());
160 
161     switch (Op.getType()) {
162     case SCEVType::INT:
163     case SCEVType::PARAM:
164       // We currently do not represent a truncate expression as an affine
165       // expression. If it is constant during Scop execution, we treat it as a
166       // parameter.
167       return ValidatorResult(SCEVType::PARAM, Expr);
168     case SCEVType::IV:
169       DEBUG(dbgs() << "INVALID: ZeroExtend of SCEVType::IV expression");
170       return ValidatorResult(SCEVType::INVALID);
171     case SCEVType::INVALID:
172       return Op;
173     }
174 
175     llvm_unreachable("Unknown SCEVType");
176   }
177 
178   class ValidatorResult visitSignExtendExpr(const SCEVSignExtendExpr *Expr) {
179     // We currently allow only signed SCEV expressions. In the case of a
180     // signed value, a sign extend is a noop.
181     //
182     // TODO: Reconsider this when we add support for unsigned values.
183     return visit(Expr->getOperand());
184   }
185 
186   class ValidatorResult visitAddExpr(const SCEVAddExpr *Expr) {
187     ValidatorResult Return(SCEVType::INT);
188 
189     for (int i = 0, e = Expr->getNumOperands(); i < e; ++i) {
190       ValidatorResult Op = visit(Expr->getOperand(i));
191       Return.merge(Op);
192 
193       // Early exit.
194       if (!Return.isValid())
195         break;
196     }
197 
198     // TODO: Check for NSW and NUW.
199     return Return;
200   }
201 
202   class ValidatorResult visitMulExpr(const SCEVMulExpr *Expr) {
203     ValidatorResult Return(SCEVType::INT);
204 
205     bool HasMultipleParams = false;
206 
207     for (int i = 0, e = Expr->getNumOperands(); i < e; ++i) {
208       ValidatorResult Op = visit(Expr->getOperand(i));
209 
210       if (Op.isINT())
211         continue;
212 
213       if (Op.isPARAM() && Return.isPARAM()) {
214         HasMultipleParams = true;
215         continue;
216       }
217 
218       if ((Op.isIV() || Op.isPARAM()) && !Return.isINT()) {
219         DEBUG(dbgs() << "INVALID: More than one non-int operand in MulExpr\n"
220                      << "\tExpr: " << *Expr << "\n"
221                      << "\tPrevious expression type: " << Return << "\n"
222                      << "\tNext operand (" << Op
223                      << "): " << *Expr->getOperand(i) << "\n");
224 
225         return ValidatorResult(SCEVType::INVALID);
226       }
227 
228       Return.merge(Op);
229     }
230 
231     if (HasMultipleParams)
232       return ValidatorResult(SCEVType::PARAM, Expr);
233 
234     // TODO: Check for NSW and NUW.
235     return Return;
236   }
237 
238   class ValidatorResult visitUDivExpr(const SCEVUDivExpr *Expr) {
239     ValidatorResult LHS = visit(Expr->getLHS());
240     ValidatorResult RHS = visit(Expr->getRHS());
241 
242     // We currently do not represent an unsigned division as an affine
243     // expression. If the division is constant during Scop execution we treat it
244     // as a parameter, otherwise we bail out.
245     if (LHS.isConstant() && RHS.isConstant())
246       return ValidatorResult(SCEVType::PARAM, Expr);
247 
248     DEBUG(dbgs() << "INVALID: unsigned division of non-constant expressions");
249     return ValidatorResult(SCEVType::INVALID);
250   }
251 
252   class ValidatorResult visitAddRecExpr(const SCEVAddRecExpr *Expr) {
253     if (!Expr->isAffine()) {
254       DEBUG(dbgs() << "INVALID: AddRec is not affine");
255       return ValidatorResult(SCEVType::INVALID);
256     }
257 
258     ValidatorResult Start = visit(Expr->getStart());
259     ValidatorResult Recurrence = visit(Expr->getStepRecurrence(SE));
260 
261     if (!Start.isValid())
262       return Start;
263 
264     if (!Recurrence.isValid())
265       return Recurrence;
266 
267     if (R->contains(Expr->getLoop())) {
268       if (Recurrence.isINT()) {
269         ValidatorResult Result(SCEVType::IV);
270         Result.addParamsFrom(Start);
271         return Result;
272       }
273 
274       DEBUG(dbgs() << "INVALID: AddRec within scop has non-int"
275                       "recurrence part");
276       return ValidatorResult(SCEVType::INVALID);
277     }
278 
279     assert(Start.isConstant() && Recurrence.isConstant() &&
280            "Expected 'Start' and 'Recurrence' to be constant");
281 
282     // Directly generate ValidatorResult for Expr if 'start' is zero.
283     if (Expr->getStart()->isZero())
284       return ValidatorResult(SCEVType::PARAM, Expr);
285 
286     // Translate AddRecExpr from '{start, +, inc}' into 'start + {0, +, inc}'
287     // if 'start' is not zero.
288     const SCEV *ZeroStartExpr = SE.getAddRecExpr(
289         SE.getConstant(Expr->getStart()->getType(), 0),
290         Expr->getStepRecurrence(SE), Expr->getLoop(), SCEV::FlagAnyWrap);
291 
292     ValidatorResult ZeroStartResult =
293         ValidatorResult(SCEVType::PARAM, ZeroStartExpr);
294     ZeroStartResult.addParamsFrom(Start);
295 
296     return ZeroStartResult;
297   }
298 
299   class ValidatorResult visitSMaxExpr(const SCEVSMaxExpr *Expr) {
300     ValidatorResult Return(SCEVType::INT);
301 
302     for (int i = 0, e = Expr->getNumOperands(); i < e; ++i) {
303       ValidatorResult Op = visit(Expr->getOperand(i));
304 
305       if (!Op.isValid())
306         return Op;
307 
308       Return.merge(Op);
309     }
310 
311     return Return;
312   }
313 
314   class ValidatorResult visitUMaxExpr(const SCEVUMaxExpr *Expr) {
315     // We do not support unsigned operations. If 'Expr' is constant during Scop
316     // execution we treat this as a parameter, otherwise we bail out.
317     for (int i = 0, e = Expr->getNumOperands(); i < e; ++i) {
318       ValidatorResult Op = visit(Expr->getOperand(i));
319 
320       if (!Op.isConstant()) {
321         DEBUG(dbgs() << "INVALID: UMaxExpr has a non-constant operand");
322         return ValidatorResult(SCEVType::INVALID);
323       }
324     }
325 
326     return ValidatorResult(SCEVType::PARAM, Expr);
327   }
328 
329   ValidatorResult visitUnknown(const SCEVUnknown *Expr) {
330     Value *V = Expr->getValue();
331 
332     // We currently only support integer types. It may be useful to support
333     // pointer types, e.g. to support code like:
334     //
335     //   if (A)
336     //     A[i] = 1;
337     //
338     // See test/CodeGen/20120316-InvalidCast.ll
339     if (!(Expr->getType()->isIntegerTy() || Expr->getType()->isPointerTy())) {
340       DEBUG(dbgs() << "INVALID: UnknownExpr is not an integer or pointer type");
341       return ValidatorResult(SCEVType::INVALID);
342     }
343 
344     if (isa<UndefValue>(V)) {
345       DEBUG(dbgs() << "INVALID: UnknownExpr references an undef value");
346       return ValidatorResult(SCEVType::INVALID);
347     }
348 
349     if (Instruction *I = dyn_cast<Instruction>(Expr->getValue()))
350       if (R->contains(I)) {
351         DEBUG(dbgs() << "INVALID: UnknownExpr references an instruction "
352                         "within the region\n");
353         return ValidatorResult(SCEVType::INVALID);
354       }
355 
356     if (BaseAddress == V) {
357       DEBUG(dbgs() << "INVALID: UnknownExpr references BaseAddress\n");
358       return ValidatorResult(SCEVType::INVALID);
359     }
360 
361     return ValidatorResult(SCEVType::PARAM, Expr);
362   }
363 };
364 
365 /// @brief Check whether a SCEV refers to an SSA name defined inside a region.
366 ///
367 struct SCEVInRegionDependences
368     : public SCEVVisitor<SCEVInRegionDependences, bool> {
369 public:
370   /// Returns true when the SCEV has SSA names defined in region R.
371   static bool hasDependences(const SCEV *S, const Region *R) {
372     SCEVInRegionDependences Ignore(R);
373     return Ignore.visit(S);
374   }
375 
376   SCEVInRegionDependences(const Region *R) : R(R) {}
377 
378   bool visit(const SCEV *Expr) {
379     return SCEVVisitor<SCEVInRegionDependences, bool>::visit(Expr);
380   }
381 
382   bool visitConstant(const SCEVConstant *Constant) { return false; }
383 
384   bool visitTruncateExpr(const SCEVTruncateExpr *Expr) {
385     return visit(Expr->getOperand());
386   }
387 
388   bool visitZeroExtendExpr(const SCEVZeroExtendExpr *Expr) {
389     return visit(Expr->getOperand());
390   }
391 
392   bool visitSignExtendExpr(const SCEVSignExtendExpr *Expr) {
393     return visit(Expr->getOperand());
394   }
395 
396   bool visitAddExpr(const SCEVAddExpr *Expr) {
397     for (int i = 0, e = Expr->getNumOperands(); i < e; ++i)
398       if (visit(Expr->getOperand(i)))
399         return true;
400 
401     return false;
402   }
403 
404   bool visitMulExpr(const SCEVMulExpr *Expr) {
405     for (int i = 0, e = Expr->getNumOperands(); i < e; ++i)
406       if (visit(Expr->getOperand(i)))
407         return true;
408 
409     return false;
410   }
411 
412   bool visitUDivExpr(const SCEVUDivExpr *Expr) {
413     if (visit(Expr->getLHS()))
414       return true;
415 
416     if (visit(Expr->getRHS()))
417       return true;
418 
419     return false;
420   }
421 
422   bool visitAddRecExpr(const SCEVAddRecExpr *Expr) {
423     if (visit(Expr->getStart()))
424       return true;
425 
426     for (size_t i = 0; i < Expr->getNumOperands(); ++i)
427       if (visit(Expr->getOperand(i)))
428         return true;
429 
430     return false;
431   }
432 
433   bool visitSMaxExpr(const SCEVSMaxExpr *Expr) {
434     for (size_t i = 0; i < Expr->getNumOperands(); ++i)
435       if (visit(Expr->getOperand(i)))
436         return true;
437 
438     return false;
439   }
440 
441   bool visitUMaxExpr(const SCEVUMaxExpr *Expr) {
442     for (size_t i = 0; i < Expr->getNumOperands(); ++i)
443       if (visit(Expr->getOperand(i)))
444         return true;
445 
446     return false;
447   }
448 
449   bool visitUnknown(const SCEVUnknown *Expr) {
450     Instruction *Inst = dyn_cast<Instruction>(Expr->getValue());
451 
452     // Return true when Inst is defined inside the region R.
453     if (Inst && R->contains(Inst))
454       return true;
455 
456     return false;
457   }
458 
459 private:
460   const Region *R;
461 };
462 
463 namespace polly {
464 /// Find all loops referenced in SCEVAddRecExprs.
465 class SCEVFindLoops {
466   SetVector<const Loop *> &Loops;
467 
468 public:
469   SCEVFindLoops(SetVector<const Loop *> &Loops) : Loops(Loops) {}
470 
471   bool follow(const SCEV *S) {
472     if (const SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(S))
473       Loops.insert(AddRec->getLoop());
474     return true;
475   }
476   bool isDone() { return false; }
477 };
478 
479 void findLoops(const SCEV *Expr, SetVector<const Loop *> &Loops) {
480   SCEVFindLoops FindLoops(Loops);
481   SCEVTraversal<SCEVFindLoops> ST(FindLoops);
482   ST.visitAll(Expr);
483 }
484 
485 /// Find all values referenced in SCEVUnknowns.
486 class SCEVFindValues {
487   SetVector<Value *> &Values;
488 
489 public:
490   SCEVFindValues(SetVector<Value *> &Values) : Values(Values) {}
491 
492   bool follow(const SCEV *S) {
493     if (const SCEVUnknown *Unknown = dyn_cast<SCEVUnknown>(S))
494       Values.insert(Unknown->getValue());
495     return true;
496   }
497   bool isDone() { return false; }
498 };
499 
500 void findValues(const SCEV *Expr, SetVector<Value *> &Values) {
501   SCEVFindValues FindValues(Values);
502   SCEVTraversal<SCEVFindValues> ST(FindValues);
503   ST.visitAll(Expr);
504 }
505 
506 bool hasScalarDepsInsideRegion(const SCEV *Expr, const Region *R) {
507   return SCEVInRegionDependences::hasDependences(Expr, R);
508 }
509 
510 bool isAffineExpr(const Region *R, const SCEV *Expr, ScalarEvolution &SE,
511                   const Value *BaseAddress) {
512   if (isa<SCEVCouldNotCompute>(Expr))
513     return false;
514 
515   SCEVValidator Validator(R, SE, BaseAddress);
516   DEBUG({
517     dbgs() << "\n";
518     dbgs() << "Expr: " << *Expr << "\n";
519     dbgs() << "Region: " << R->getNameStr() << "\n";
520     dbgs() << " -> ";
521   });
522 
523   ValidatorResult Result = Validator.visit(Expr);
524 
525   DEBUG({
526     if (Result.isValid())
527       dbgs() << "VALID\n";
528     dbgs() << "\n";
529   });
530 
531   return Result.isValid();
532 }
533 
534 std::vector<const SCEV *> getParamsInAffineExpr(const Region *R,
535                                                 const SCEV *Expr,
536                                                 ScalarEvolution &SE,
537                                                 const Value *BaseAddress) {
538   if (isa<SCEVCouldNotCompute>(Expr))
539     return std::vector<const SCEV *>();
540 
541   SCEVValidator Validator(R, SE, BaseAddress);
542   ValidatorResult Result = Validator.visit(Expr);
543 
544   return Result.getParameters();
545 }
546 }
547