xref: /llvm-project/polly/lib/Support/SCEVValidator.cpp (revision 58032cb029dcb366be868db48c89af128cce1dd4)
1 
2 #include "polly/Support/SCEVValidator.h"
3 #include "polly/ScopInfo.h"
4 
5 #define DEBUG_TYPE "polly-scev-validator"
6 #include "llvm/Support/Debug.h"
7 #include "llvm/Analysis/ScalarEvolution.h"
8 #include "llvm/Analysis/ScalarEvolutionExpressions.h"
9 #include "llvm/Analysis/RegionInfo.h"
10 
11 #include <vector>
12 
13 using namespace llvm;
14 
namespace SCEVType {
/// @brief The type of a SCEV
///
/// The validator classifies every SCEV it visits with one of the types below.
/// The numeric order of the enumerators is significant: a SCEV of type X may
/// only have subexpressions whose type is smaller than or equal to X. This is
/// why combining two results keeps the larger of the two types.
enum TYPE {
  INT = 0,    ///< An integer value.
  PARAM = 1,  ///< Constant while the SCoP executes, but possibly depending on
              ///< parameters that are unknown at compile time.
  IV = 2,     ///< An expression that may change while the SCoP executes.
  INVALID = 3 ///< An invalid expression.
};
} // namespace SCEVType
37 
38 /// @brief The result the validator returns for a SCEV expression.
39 class ValidatorResult {
40   /// @brief The type of the expression
41   SCEVType::TYPE Type;
42 
43   /// @brief The set of Parameters in the expression.
44   std::vector<const SCEV *> Parameters;
45 
46 public:
47   /// @brief The copy constructor
48   ValidatorResult(const ValidatorResult &Source) {
49     Type = Source.Type;
50     Parameters = Source.Parameters;
51   }
52 
53   /// @brief Construct a result with a certain type and no parameters.
54   ValidatorResult(SCEVType::TYPE Type) : Type(Type) {
55     assert(Type != SCEVType::PARAM && "Did you forget to pass the parameter");
56   }
57 
58   /// @brief Construct a result with a certain type and a single parameter.
59   ValidatorResult(SCEVType::TYPE Type, const SCEV *Expr) : Type(Type) {
60     Parameters.push_back(Expr);
61   }
62 
63   /// @brief Get the type of the ValidatorResult.
64   SCEVType::TYPE getType() { return Type; }
65 
66   /// @brief Is the analyzed SCEV constant during the execution of the SCoP.
67   bool isConstant() { return Type == SCEVType::INT || Type == SCEVType::PARAM; }
68 
69   /// @brief Is the analyzed SCEV valid.
70   bool isValid() { return Type != SCEVType::INVALID; }
71 
72   /// @brief Is the analyzed SCEV of Type IV.
73   bool isIV() { return Type == SCEVType::IV; }
74 
75   /// @brief Is the analyzed SCEV of Type INT.
76   bool isINT() { return Type == SCEVType::INT; }
77 
78   /// @brief Is the analyzed SCEV of Type PARAM.
79   bool isPARAM() { return Type == SCEVType::PARAM; }
80 
81   /// @brief Get the parameters of this validator result.
82   std::vector<const SCEV *> getParameters() { return Parameters; }
83 
84   /// @brief Add the parameters of Source to this result.
85   void addParamsFrom(class ValidatorResult &Source) {
86     Parameters.insert(Parameters.end(), Source.Parameters.begin(),
87                       Source.Parameters.end());
88   }
89 
90   /// @brief Merge a result.
91   ///
92   /// This means to merge the parameters and to set the Type to the most
93   /// specific Type that matches both.
94   void merge(class ValidatorResult &ToMerge) {
95     Type = std::max(Type, ToMerge.Type);
96     addParamsFrom(ToMerge);
97   }
98 
99   void print(raw_ostream &OS) {
100     switch (Type) {
101     case SCEVType::INT:
102       OS << "SCEVType::INT";
103       break;
104     case SCEVType::PARAM:
105       OS << "SCEVType::PARAM";
106       break;
107     case SCEVType::IV:
108       OS << "SCEVType::IV";
109       break;
110     case SCEVType::INVALID:
111       OS << "SCEVType::INVALID";
112       break;
113     }
114   }
115 };
116 
117 raw_ostream &operator<<(raw_ostream &OS, class ValidatorResult &VR) {
118   VR.print(OS);
119   return OS;
120 }
121 
122 /// Check if a SCEV is valid in a SCoP.
123 struct SCEVValidator
124     : public SCEVVisitor<SCEVValidator, class ValidatorResult> {
125 private:
126   const Region *R;
127   ScalarEvolution &SE;
128   const Value *BaseAddress;
129 
130 public:
131   SCEVValidator(const Region *R, ScalarEvolution &SE, const Value *BaseAddress)
132       : R(R), SE(SE), BaseAddress(BaseAddress) {}
133 
134   class ValidatorResult visitConstant(const SCEVConstant *Constant) {
135     return ValidatorResult(SCEVType::INT);
136   }
137 
138   class ValidatorResult visitTruncateExpr(const SCEVTruncateExpr *Expr) {
139     ValidatorResult Op = visit(Expr->getOperand());
140 
141     switch (Op.getType()) {
142     case SCEVType::INT:
143     case SCEVType::PARAM:
144       // We currently do not represent a truncate expression as an affine
145       // expression. If it is constant during Scop execution, we treat it as a
146       // parameter.
147       return ValidatorResult(SCEVType::PARAM, Expr);
148     case SCEVType::IV:
149       DEBUG(dbgs() << "INVALID: Truncation of SCEVType::IV expression");
150       return ValidatorResult(SCEVType::INVALID);
151     case SCEVType::INVALID:
152       return Op;
153     }
154 
155     llvm_unreachable("Unknown SCEVType");
156   }
157 
158   class ValidatorResult visitZeroExtendExpr(const SCEVZeroExtendExpr *Expr) {
159     ValidatorResult Op = visit(Expr->getOperand());
160 
161     switch (Op.getType()) {
162     case SCEVType::INT:
163     case SCEVType::PARAM:
164       // We currently do not represent a truncate expression as an affine
165       // expression. If it is constant during Scop execution, we treat it as a
166       // parameter.
167       return ValidatorResult(SCEVType::PARAM, Expr);
168     case SCEVType::IV:
169       DEBUG(dbgs() << "INVALID: ZeroExtend of SCEVType::IV expression");
170       return ValidatorResult(SCEVType::INVALID);
171     case SCEVType::INVALID:
172       return Op;
173     }
174 
175     llvm_unreachable("Unknown SCEVType");
176   }
177 
178   class ValidatorResult visitSignExtendExpr(const SCEVSignExtendExpr *Expr) {
179     // We currently allow only signed SCEV expressions. In the case of a
180     // signed value, a sign extend is a noop.
181     //
182     // TODO: Reconsider this when we add support for unsigned values.
183     return visit(Expr->getOperand());
184   }
185 
186   class ValidatorResult visitAddExpr(const SCEVAddExpr *Expr) {
187     ValidatorResult Return(SCEVType::INT);
188 
189     for (int i = 0, e = Expr->getNumOperands(); i < e; ++i) {
190       ValidatorResult Op = visit(Expr->getOperand(i));
191       Return.merge(Op);
192 
193       // Early exit.
194       if (!Return.isValid())
195         break;
196     }
197 
198     // TODO: Check for NSW and NUW.
199     return Return;
200   }
201 
202   class ValidatorResult visitMulExpr(const SCEVMulExpr *Expr) {
203     ValidatorResult Return(SCEVType::INT);
204 
205     bool HasMultipleParams = false;
206 
207     for (int i = 0, e = Expr->getNumOperands(); i < e; ++i) {
208       ValidatorResult Op = visit(Expr->getOperand(i));
209 
210       if (Op.isINT())
211         continue;
212 
213       if (Op.isPARAM() && Return.isPARAM()) {
214         HasMultipleParams = true;
215         continue;
216       }
217 
218       if ((Op.isIV() || Op.isPARAM()) && !Return.isINT()) {
219         DEBUG(dbgs() << "INVALID: More than one non-int operand in MulExpr\n"
220                      << "\tExpr: " << *Expr << "\n"
221                      << "\tPrevious expression type: " << Return << "\n"
222                      << "\tNext operand (" << Op
223                      << "): " << *Expr->getOperand(i) << "\n");
224 
225         return ValidatorResult(SCEVType::INVALID);
226       }
227 
228       Return.merge(Op);
229     }
230 
231     if (HasMultipleParams)
232       return ValidatorResult(SCEVType::PARAM, Expr);
233 
234     // TODO: Check for NSW and NUW.
235     return Return;
236   }
237 
238   class ValidatorResult visitUDivExpr(const SCEVUDivExpr *Expr) {
239     ValidatorResult LHS = visit(Expr->getLHS());
240     ValidatorResult RHS = visit(Expr->getRHS());
241 
242     // We currently do not represent an unsigned division as an affine
243     // expression. If the division is constant during Scop execution we treat it
244     // as a parameter, otherwise we bail out.
245     if (LHS.isConstant() && RHS.isConstant())
246       return ValidatorResult(SCEVType::PARAM, Expr);
247 
248     DEBUG(dbgs() << "INVALID: unsigned division of non-constant expressions");
249     return ValidatorResult(SCEVType::INVALID);
250   }
251 
252   class ValidatorResult visitAddRecExpr(const SCEVAddRecExpr *Expr) {
253     if (!Expr->isAffine()) {
254       DEBUG(dbgs() << "INVALID: AddRec is not affine");
255       return ValidatorResult(SCEVType::INVALID);
256     }
257 
258     ValidatorResult Start = visit(Expr->getStart());
259     ValidatorResult Recurrence = visit(Expr->getStepRecurrence(SE));
260 
261     if (!Start.isValid())
262       return Start;
263 
264     if (!Recurrence.isValid())
265       return Recurrence;
266 
267     if (R->contains(Expr->getLoop())) {
268       if (Recurrence.isINT()) {
269         ValidatorResult Result(SCEVType::IV);
270         Result.addParamsFrom(Start);
271         return Result;
272       }
273 
274       DEBUG(dbgs() << "INVALID: AddRec within scop has non-int"
275                       "recurrence part");
276       return ValidatorResult(SCEVType::INVALID);
277     }
278 
279     assert(Start.isConstant() && Recurrence.isConstant() &&
280            "Expected 'Start' and 'Recurrence' to be constant");
281     return ValidatorResult(SCEVType::PARAM, Expr);
282   }
283 
284   class ValidatorResult visitSMaxExpr(const SCEVSMaxExpr *Expr) {
285     ValidatorResult Return(SCEVType::INT, Expr);
286 
287     for (int i = 0, e = Expr->getNumOperands(); i < e; ++i) {
288       ValidatorResult Op = visit(Expr->getOperand(i));
289 
290       if (!Op.isValid())
291         return Op;
292 
293       Return.merge(Op);
294     }
295 
296     return Return;
297   }
298 
299   class ValidatorResult visitUMaxExpr(const SCEVUMaxExpr *Expr) {
300     // We do not support unsigned operations. If 'Expr' is constant during Scop
301     // execution we treat this as a parameter, otherwise we bail out.
302     for (int i = 0, e = Expr->getNumOperands(); i < e; ++i) {
303       ValidatorResult Op = visit(Expr->getOperand(i));
304 
305       if (!Op.isConstant()) {
306         DEBUG(dbgs() << "INVALID: UMaxExpr has a non-constant operand");
307         return ValidatorResult(SCEVType::INVALID);
308       }
309     }
310 
311     return ValidatorResult(SCEVType::PARAM, Expr);
312   }
313 
314   ValidatorResult visitUnknown(const SCEVUnknown *Expr) {
315     Value *V = Expr->getValue();
316 
317     // We currently only support integer types. It may be useful to support
318     // pointer types, e.g. to support code like:
319     //
320     //   if (A)
321     //     A[i] = 1;
322     //
323     // See test/CodeGen/20120316-InvalidCast.ll
324     if (!Expr->getType()->isIntegerTy()) {
325       DEBUG(dbgs() << "INVALID: UnknownExpr is not an integer type");
326       return ValidatorResult(SCEVType::INVALID);
327     }
328 
329     if (isa<UndefValue>(V)) {
330       DEBUG(dbgs() << "INVALID: UnknownExpr references an undef value");
331       return ValidatorResult(SCEVType::INVALID);
332     }
333 
334     if (Instruction *I = dyn_cast<Instruction>(Expr->getValue()))
335       if (R->contains(I)) {
336         DEBUG(dbgs() << "INVALID: UnknownExpr references an instruction "
337                         "within the region\n");
338         return ValidatorResult(SCEVType::INVALID);
339       }
340 
341     if (BaseAddress == V) {
342       DEBUG(dbgs() << "INVALID: UnknownExpr references BaseAddress\n");
343       return ValidatorResult(SCEVType::INVALID);
344     }
345 
346     return ValidatorResult(SCEVType::PARAM, Expr);
347   }
348 };
349 
350 /// @brief Check whether a SCEV refers to an SSA name defined inside a region.
351 ///
352 struct SCEVInRegionDependences
353     : public SCEVVisitor<SCEVInRegionDependences, bool> {
354 public:
355 
356   /// Returns true when the SCEV has SSA names defined in region R.
357   static bool hasDependences(const SCEV *S, const Region *R) {
358     SCEVInRegionDependences Ignore(R);
359     return Ignore.visit(S);
360   }
361 
362   SCEVInRegionDependences(const Region *R) : R(R) {}
363 
364   bool visit(const SCEV *Expr) {
365     return SCEVVisitor<SCEVInRegionDependences, bool>::visit(Expr);
366   }
367 
368   bool visitConstant(const SCEVConstant *Constant) { return false; }
369 
370   bool visitTruncateExpr(const SCEVTruncateExpr *Expr) {
371     return visit(Expr->getOperand());
372   }
373 
374   bool visitZeroExtendExpr(const SCEVZeroExtendExpr *Expr) {
375     return visit(Expr->getOperand());
376   }
377 
378   bool visitSignExtendExpr(const SCEVSignExtendExpr *Expr) {
379     return visit(Expr->getOperand());
380   }
381 
382   bool visitAddExpr(const SCEVAddExpr *Expr) {
383     for (int i = 0, e = Expr->getNumOperands(); i < e; ++i)
384       if (visit(Expr->getOperand(i)))
385         return true;
386 
387     return false;
388   }
389 
390   bool visitMulExpr(const SCEVMulExpr *Expr) {
391     for (int i = 0, e = Expr->getNumOperands(); i < e; ++i)
392       if (visit(Expr->getOperand(i)))
393         return true;
394 
395     return false;
396   }
397 
398   bool visitUDivExpr(const SCEVUDivExpr *Expr) {
399     if (visit(Expr->getLHS()))
400       return true;
401 
402     if (visit(Expr->getRHS()))
403       return true;
404 
405     return false;
406   }
407 
408   bool visitAddRecExpr(const SCEVAddRecExpr *Expr) {
409     if (visit(Expr->getStart()))
410       return true;
411 
412     for (size_t i = 0; i < Expr->getNumOperands(); ++i)
413       if (visit(Expr->getOperand(i)))
414         return true;
415 
416     return false;
417   }
418 
419   bool visitSMaxExpr(const SCEVSMaxExpr *Expr) {
420     for (size_t i = 0; i < Expr->getNumOperands(); ++i)
421       if (visit(Expr->getOperand(i)))
422         return true;
423 
424     return false;
425   }
426 
427   bool visitUMaxExpr(const SCEVUMaxExpr *Expr) {
428     for (size_t i = 0; i < Expr->getNumOperands(); ++i)
429       if (visit(Expr->getOperand(i)))
430         return true;
431 
432     return false;
433   }
434 
435   bool visitUnknown(const SCEVUnknown *Expr) {
436     Instruction *Inst = dyn_cast<Instruction>(Expr->getValue());
437 
438     // Return true when Inst is defined inside the region R.
439     if (Inst && R->contains(Inst))
440       return true;
441 
442     return false;
443   }
444 
445 private:
446   const Region *R;
447 };
448 
449 namespace polly {
450 bool hasScalarDepsInsideRegion(const SCEV *Expr, const Region *R) {
451   return SCEVInRegionDependences::hasDependences(Expr, R);
452 }
453 
454 bool isAffineExpr(const Region *R, const SCEV *Expr, ScalarEvolution &SE,
455                   const Value *BaseAddress) {
456   if (isa<SCEVCouldNotCompute>(Expr))
457     return false;
458 
459   SCEVValidator Validator(R, SE, BaseAddress);
460   DEBUG(dbgs() << "\n"; dbgs() << "Expr: " << *Expr << "\n";
461         dbgs() << "Region: " << R->getNameStr() << "\n"; dbgs() << " -> ");
462 
463   ValidatorResult Result = Validator.visit(Expr);
464 
465   DEBUG(if (Result.isValid()) dbgs() << "VALID\n"; dbgs() << "\n";);
466 
467   return Result.isValid();
468 }
469 
470 std::vector<const SCEV *> getParamsInAffineExpr(const Region *R,
471                                                 const SCEV *Expr,
472                                                 ScalarEvolution &SE,
473                                                 const Value *BaseAddress) {
474   if (isa<SCEVCouldNotCompute>(Expr))
475     return std::vector<const SCEV *>();
476 
477   SCEVValidator Validator(R, SE, BaseAddress);
478   ValidatorResult Result = Validator.visit(Expr);
479 
480   return Result.getParameters();
481 }
482 }
483