xref: /llvm-project/polly/lib/Support/SCEVValidator.cpp (revision ba0d09227c66b5e7731d02b24c2f07014587ca0f)
1 
2 #include "polly/Support/SCEVValidator.h"
3 #include "polly/ScopInfo.h"
4 #include "llvm/Analysis/RegionInfo.h"
5 #include "llvm/Analysis/ScalarEvolution.h"
6 #include "llvm/Analysis/ScalarEvolutionExpressions.h"
7 #include "llvm/Support/Debug.h"
8 #include <vector>
9 
10 using namespace llvm;
11 
12 #define DEBUG_TYPE "polly-scev-validator"
13 
namespace SCEVType {
/// @brief Classification the validator attaches to every SCEV it visits.
///
/// The enumerators are listed in increasing order of generality and that
/// order is relied upon: a SCEV classified as X may only have subexpressions
/// whose classification compares less than or equal to X, and combining two
/// classifications keeps the larger one.
enum TYPE {
  INT,    ///< A plain integer value.
  PARAM,  ///< Constant while the SCoP executes, but possibly depending on
          ///< parameters that are not known at compile time.
  IV,     ///< May change while the SCoP executes.
  INVALID ///< An expression that cannot be represented.
};
} // namespace SCEVType
36 
37 /// @brief The result the validator returns for a SCEV expression.
38 class ValidatorResult {
39   /// @brief The type of the expression
40   SCEVType::TYPE Type;
41 
42   /// @brief The set of Parameters in the expression.
43   std::vector<const SCEV *> Parameters;
44 
45 public:
46   /// @brief The copy constructor
47   ValidatorResult(const ValidatorResult &Source) {
48     Type = Source.Type;
49     Parameters = Source.Parameters;
50   }
51 
52   /// @brief Construct a result with a certain type and no parameters.
53   ValidatorResult(SCEVType::TYPE Type) : Type(Type) {
54     assert(Type != SCEVType::PARAM && "Did you forget to pass the parameter");
55   }
56 
57   /// @brief Construct a result with a certain type and a single parameter.
58   ValidatorResult(SCEVType::TYPE Type, const SCEV *Expr) : Type(Type) {
59     Parameters.push_back(Expr);
60   }
61 
62   /// @brief Get the type of the ValidatorResult.
63   SCEVType::TYPE getType() { return Type; }
64 
65   /// @brief Is the analyzed SCEV constant during the execution of the SCoP.
66   bool isConstant() { return Type == SCEVType::INT || Type == SCEVType::PARAM; }
67 
68   /// @brief Is the analyzed SCEV valid.
69   bool isValid() { return Type != SCEVType::INVALID; }
70 
71   /// @brief Is the analyzed SCEV of Type IV.
72   bool isIV() { return Type == SCEVType::IV; }
73 
74   /// @brief Is the analyzed SCEV of Type INT.
75   bool isINT() { return Type == SCEVType::INT; }
76 
77   /// @brief Is the analyzed SCEV of Type PARAM.
78   bool isPARAM() { return Type == SCEVType::PARAM; }
79 
80   /// @brief Get the parameters of this validator result.
81   std::vector<const SCEV *> getParameters() { return Parameters; }
82 
83   /// @brief Add the parameters of Source to this result.
84   void addParamsFrom(const ValidatorResult &Source) {
85     Parameters.insert(Parameters.end(), Source.Parameters.begin(),
86                       Source.Parameters.end());
87   }
88 
89   /// @brief Merge a result.
90   ///
91   /// This means to merge the parameters and to set the Type to the most
92   /// specific Type that matches both.
93   void merge(const ValidatorResult &ToMerge) {
94     Type = std::max(Type, ToMerge.Type);
95     addParamsFrom(ToMerge);
96   }
97 
98   void print(raw_ostream &OS) {
99     switch (Type) {
100     case SCEVType::INT:
101       OS << "SCEVType::INT";
102       break;
103     case SCEVType::PARAM:
104       OS << "SCEVType::PARAM";
105       break;
106     case SCEVType::IV:
107       OS << "SCEVType::IV";
108       break;
109     case SCEVType::INVALID:
110       OS << "SCEVType::INVALID";
111       break;
112     }
113   }
114 };
115 
116 raw_ostream &operator<<(raw_ostream &OS, class ValidatorResult &VR) {
117   VR.print(OS);
118   return OS;
119 }
120 
121 /// Check if a SCEV is valid in a SCoP.
122 struct SCEVValidator
123     : public SCEVVisitor<SCEVValidator, class ValidatorResult> {
124 private:
125   const Region *R;
126   ScalarEvolution &SE;
127   const Value *BaseAddress;
128 
129 public:
130   SCEVValidator(const Region *R, ScalarEvolution &SE, const Value *BaseAddress)
131       : R(R), SE(SE), BaseAddress(BaseAddress) {}
132 
133   class ValidatorResult visitConstant(const SCEVConstant *Constant) {
134     return ValidatorResult(SCEVType::INT);
135   }
136 
137   class ValidatorResult visitTruncateExpr(const SCEVTruncateExpr *Expr) {
138     ValidatorResult Op = visit(Expr->getOperand());
139 
140     switch (Op.getType()) {
141     case SCEVType::INT:
142     case SCEVType::PARAM:
143       // We currently do not represent a truncate expression as an affine
144       // expression. If it is constant during Scop execution, we treat it as a
145       // parameter.
146       return ValidatorResult(SCEVType::PARAM, Expr);
147     case SCEVType::IV:
148       DEBUG(dbgs() << "INVALID: Truncation of SCEVType::IV expression");
149       return ValidatorResult(SCEVType::INVALID);
150     case SCEVType::INVALID:
151       return Op;
152     }
153 
154     llvm_unreachable("Unknown SCEVType");
155   }
156 
157   class ValidatorResult visitZeroExtendExpr(const SCEVZeroExtendExpr *Expr) {
158     ValidatorResult Op = visit(Expr->getOperand());
159 
160     switch (Op.getType()) {
161     case SCEVType::INT:
162     case SCEVType::PARAM:
163       // We currently do not represent a truncate expression as an affine
164       // expression. If it is constant during Scop execution, we treat it as a
165       // parameter.
166       return ValidatorResult(SCEVType::PARAM, Expr);
167     case SCEVType::IV:
168       DEBUG(dbgs() << "INVALID: ZeroExtend of SCEVType::IV expression");
169       return ValidatorResult(SCEVType::INVALID);
170     case SCEVType::INVALID:
171       return Op;
172     }
173 
174     llvm_unreachable("Unknown SCEVType");
175   }
176 
177   class ValidatorResult visitSignExtendExpr(const SCEVSignExtendExpr *Expr) {
178     // We currently allow only signed SCEV expressions. In the case of a
179     // signed value, a sign extend is a noop.
180     //
181     // TODO: Reconsider this when we add support for unsigned values.
182     return visit(Expr->getOperand());
183   }
184 
185   class ValidatorResult visitAddExpr(const SCEVAddExpr *Expr) {
186     ValidatorResult Return(SCEVType::INT);
187 
188     for (int i = 0, e = Expr->getNumOperands(); i < e; ++i) {
189       ValidatorResult Op = visit(Expr->getOperand(i));
190       Return.merge(Op);
191 
192       // Early exit.
193       if (!Return.isValid())
194         break;
195     }
196 
197     // TODO: Check for NSW and NUW.
198     return Return;
199   }
200 
201   class ValidatorResult visitMulExpr(const SCEVMulExpr *Expr) {
202     ValidatorResult Return(SCEVType::INT);
203 
204     bool HasMultipleParams = false;
205 
206     for (int i = 0, e = Expr->getNumOperands(); i < e; ++i) {
207       ValidatorResult Op = visit(Expr->getOperand(i));
208 
209       if (Op.isINT())
210         continue;
211 
212       if (Op.isPARAM() && Return.isPARAM()) {
213         HasMultipleParams = true;
214         continue;
215       }
216 
217       if ((Op.isIV() || Op.isPARAM()) && !Return.isINT()) {
218         DEBUG(dbgs() << "INVALID: More than one non-int operand in MulExpr\n"
219                      << "\tExpr: " << *Expr << "\n"
220                      << "\tPrevious expression type: " << Return << "\n"
221                      << "\tNext operand (" << Op
222                      << "): " << *Expr->getOperand(i) << "\n");
223 
224         return ValidatorResult(SCEVType::INVALID);
225       }
226 
227       Return.merge(Op);
228     }
229 
230     if (HasMultipleParams && Return.isValid())
231       return ValidatorResult(SCEVType::PARAM, Expr);
232 
233     // TODO: Check for NSW and NUW.
234     return Return;
235   }
236 
237   class ValidatorResult visitUDivExpr(const SCEVUDivExpr *Expr) {
238     ValidatorResult LHS = visit(Expr->getLHS());
239     ValidatorResult RHS = visit(Expr->getRHS());
240 
241     // We currently do not represent an unsigned division as an affine
242     // expression. If the division is constant during Scop execution we treat it
243     // as a parameter, otherwise we bail out.
244     if (LHS.isConstant() && RHS.isConstant())
245       return ValidatorResult(SCEVType::PARAM, Expr);
246 
247     DEBUG(dbgs() << "INVALID: unsigned division of non-constant expressions");
248     return ValidatorResult(SCEVType::INVALID);
249   }
250 
251   class ValidatorResult visitAddRecExpr(const SCEVAddRecExpr *Expr) {
252     if (!Expr->isAffine()) {
253       DEBUG(dbgs() << "INVALID: AddRec is not affine");
254       return ValidatorResult(SCEVType::INVALID);
255     }
256 
257     ValidatorResult Start = visit(Expr->getStart());
258     ValidatorResult Recurrence = visit(Expr->getStepRecurrence(SE));
259 
260     if (!Start.isValid())
261       return Start;
262 
263     if (!Recurrence.isValid())
264       return Recurrence;
265 
266     if (R->contains(Expr->getLoop())) {
267       if (Recurrence.isINT()) {
268         ValidatorResult Result(SCEVType::IV);
269         Result.addParamsFrom(Start);
270         return Result;
271       }
272 
273       DEBUG(dbgs() << "INVALID: AddRec within scop has non-int"
274                       "recurrence part");
275       return ValidatorResult(SCEVType::INVALID);
276     }
277 
278     assert(Start.isConstant() && Recurrence.isConstant() &&
279            "Expected 'Start' and 'Recurrence' to be constant");
280 
281     // Directly generate ValidatorResult for Expr if 'start' is zero.
282     if (Expr->getStart()->isZero())
283       return ValidatorResult(SCEVType::PARAM, Expr);
284 
285     // Translate AddRecExpr from '{start, +, inc}' into 'start + {0, +, inc}'
286     // if 'start' is not zero.
287     const SCEV *ZeroStartExpr = SE.getAddRecExpr(
288         SE.getConstant(Expr->getStart()->getType(), 0),
289         Expr->getStepRecurrence(SE), Expr->getLoop(), Expr->getNoWrapFlags());
290 
291     ValidatorResult ZeroStartResult =
292         ValidatorResult(SCEVType::PARAM, ZeroStartExpr);
293     ZeroStartResult.addParamsFrom(Start);
294 
295     return ZeroStartResult;
296   }
297 
298   class ValidatorResult visitSMaxExpr(const SCEVSMaxExpr *Expr) {
299     ValidatorResult Return(SCEVType::INT);
300 
301     for (int i = 0, e = Expr->getNumOperands(); i < e; ++i) {
302       ValidatorResult Op = visit(Expr->getOperand(i));
303 
304       if (!Op.isValid())
305         return Op;
306 
307       Return.merge(Op);
308     }
309 
310     return Return;
311   }
312 
313   class ValidatorResult visitUMaxExpr(const SCEVUMaxExpr *Expr) {
314     // We do not support unsigned operations. If 'Expr' is constant during Scop
315     // execution we treat this as a parameter, otherwise we bail out.
316     for (int i = 0, e = Expr->getNumOperands(); i < e; ++i) {
317       ValidatorResult Op = visit(Expr->getOperand(i));
318 
319       if (!Op.isConstant()) {
320         DEBUG(dbgs() << "INVALID: UMaxExpr has a non-constant operand");
321         return ValidatorResult(SCEVType::INVALID);
322       }
323     }
324 
325     return ValidatorResult(SCEVType::PARAM, Expr);
326   }
327 
328   ValidatorResult visitGenericInst(Instruction *I, const SCEV *S) {
329     if (R->contains(I)) {
330       DEBUG(dbgs() << "INVALID: UnknownExpr references an instruction "
331                       "within the region\n");
332       return ValidatorResult(SCEVType::INVALID);
333     }
334 
335     return ValidatorResult(SCEVType::PARAM, S);
336   }
337 
338   ValidatorResult visitSDivInstruction(Instruction *SDiv, const SCEV *S) {
339     assert(SDiv->getOpcode() == Instruction::SDiv &&
340            "Assumed SDiv instruction!");
341 
342     auto *Divisor = SDiv->getOperand(1);
343     auto *CI = dyn_cast<ConstantInt>(Divisor);
344     if (!CI)
345       return visitGenericInst(SDiv, S);
346 
347     auto *Dividend = SDiv->getOperand(0);
348     auto *DividendSCEV = SE.getSCEV(Dividend);
349     return visit(DividendSCEV);
350   }
351 
352   ValidatorResult visitUnknown(const SCEVUnknown *Expr) {
353     Value *V = Expr->getValue();
354 
355     if (!(Expr->getType()->isIntegerTy() || Expr->getType()->isPointerTy())) {
356       DEBUG(dbgs() << "INVALID: UnknownExpr is not an integer or pointer type");
357       return ValidatorResult(SCEVType::INVALID);
358     }
359 
360     if (isa<UndefValue>(V)) {
361       DEBUG(dbgs() << "INVALID: UnknownExpr references an undef value");
362       return ValidatorResult(SCEVType::INVALID);
363     }
364 
365     if (BaseAddress == V) {
366       DEBUG(dbgs() << "INVALID: UnknownExpr references BaseAddress\n");
367       return ValidatorResult(SCEVType::INVALID);
368     }
369 
370     if (Instruction *I = dyn_cast<Instruction>(Expr->getValue())) {
371       switch (I->getOpcode()) {
372       case Instruction::SDiv:
373         return visitSDivInstruction(I, Expr);
374       default:
375         return visitGenericInst(I, Expr);
376       }
377     }
378 
379     return ValidatorResult(SCEVType::PARAM, Expr);
380   }
381 };
382 
383 /// @brief Check whether a SCEV refers to an SSA name defined inside a region.
384 ///
385 struct SCEVInRegionDependences
386     : public SCEVVisitor<SCEVInRegionDependences, bool> {
387 public:
388   /// Returns true when the SCEV has SSA names defined in region R.
389   static bool hasDependences(const SCEV *S, const Region *R) {
390     SCEVInRegionDependences Ignore(R);
391     return Ignore.visit(S);
392   }
393 
394   SCEVInRegionDependences(const Region *R) : R(R) {}
395 
396   bool visit(const SCEV *Expr) {
397     return SCEVVisitor<SCEVInRegionDependences, bool>::visit(Expr);
398   }
399 
400   bool visitConstant(const SCEVConstant *Constant) { return false; }
401 
402   bool visitTruncateExpr(const SCEVTruncateExpr *Expr) {
403     return visit(Expr->getOperand());
404   }
405 
406   bool visitZeroExtendExpr(const SCEVZeroExtendExpr *Expr) {
407     return visit(Expr->getOperand());
408   }
409 
410   bool visitSignExtendExpr(const SCEVSignExtendExpr *Expr) {
411     return visit(Expr->getOperand());
412   }
413 
414   bool visitAddExpr(const SCEVAddExpr *Expr) {
415     for (int i = 0, e = Expr->getNumOperands(); i < e; ++i)
416       if (visit(Expr->getOperand(i)))
417         return true;
418 
419     return false;
420   }
421 
422   bool visitMulExpr(const SCEVMulExpr *Expr) {
423     for (int i = 0, e = Expr->getNumOperands(); i < e; ++i)
424       if (visit(Expr->getOperand(i)))
425         return true;
426 
427     return false;
428   }
429 
430   bool visitUDivExpr(const SCEVUDivExpr *Expr) {
431     if (visit(Expr->getLHS()))
432       return true;
433 
434     if (visit(Expr->getRHS()))
435       return true;
436 
437     return false;
438   }
439 
440   bool visitAddRecExpr(const SCEVAddRecExpr *Expr) {
441     if (visit(Expr->getStart()))
442       return true;
443 
444     for (size_t i = 0; i < Expr->getNumOperands(); ++i)
445       if (visit(Expr->getOperand(i)))
446         return true;
447 
448     return false;
449   }
450 
451   bool visitSMaxExpr(const SCEVSMaxExpr *Expr) {
452     for (size_t i = 0; i < Expr->getNumOperands(); ++i)
453       if (visit(Expr->getOperand(i)))
454         return true;
455 
456     return false;
457   }
458 
459   bool visitUMaxExpr(const SCEVUMaxExpr *Expr) {
460     for (size_t i = 0; i < Expr->getNumOperands(); ++i)
461       if (visit(Expr->getOperand(i)))
462         return true;
463 
464     return false;
465   }
466 
467   bool visitUnknown(const SCEVUnknown *Expr) {
468     Instruction *Inst = dyn_cast<Instruction>(Expr->getValue());
469 
470     // Return true when Inst is defined inside the region R.
471     if (Inst && R->contains(Inst))
472       return true;
473 
474     return false;
475   }
476 
477 private:
478   const Region *R;
479 };
480 
481 namespace polly {
482 /// Find all loops referenced in SCEVAddRecExprs.
483 class SCEVFindLoops {
484   SetVector<const Loop *> &Loops;
485 
486 public:
487   SCEVFindLoops(SetVector<const Loop *> &Loops) : Loops(Loops) {}
488 
489   bool follow(const SCEV *S) {
490     if (const SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(S))
491       Loops.insert(AddRec->getLoop());
492     return true;
493   }
494   bool isDone() { return false; }
495 };
496 
497 void findLoops(const SCEV *Expr, SetVector<const Loop *> &Loops) {
498   SCEVFindLoops FindLoops(Loops);
499   SCEVTraversal<SCEVFindLoops> ST(FindLoops);
500   ST.visitAll(Expr);
501 }
502 
503 /// Find all values referenced in SCEVUnknowns.
504 class SCEVFindValues {
505   SetVector<Value *> &Values;
506 
507 public:
508   SCEVFindValues(SetVector<Value *> &Values) : Values(Values) {}
509 
510   bool follow(const SCEV *S) {
511     if (const SCEVUnknown *Unknown = dyn_cast<SCEVUnknown>(S))
512       Values.insert(Unknown->getValue());
513     return true;
514   }
515   bool isDone() { return false; }
516 };
517 
518 void findValues(const SCEV *Expr, SetVector<Value *> &Values) {
519   SCEVFindValues FindValues(Values);
520   SCEVTraversal<SCEVFindValues> ST(FindValues);
521   ST.visitAll(Expr);
522 }
523 
524 bool hasScalarDepsInsideRegion(const SCEV *Expr, const Region *R) {
525   return SCEVInRegionDependences::hasDependences(Expr, R);
526 }
527 
528 bool isAffineExpr(const Region *R, const SCEV *Expr, ScalarEvolution &SE,
529                   const Value *BaseAddress) {
530   if (isa<SCEVCouldNotCompute>(Expr))
531     return false;
532 
533   SCEVValidator Validator(R, SE, BaseAddress);
534   DEBUG({
535     dbgs() << "\n";
536     dbgs() << "Expr: " << *Expr << "\n";
537     dbgs() << "Region: " << R->getNameStr() << "\n";
538     dbgs() << " -> ";
539   });
540 
541   ValidatorResult Result = Validator.visit(Expr);
542 
543   DEBUG({
544     if (Result.isValid())
545       dbgs() << "VALID\n";
546     dbgs() << "\n";
547   });
548 
549   return Result.isValid();
550 }
551 
552 std::vector<const SCEV *> getParamsInAffineExpr(const Region *R,
553                                                 const SCEV *Expr,
554                                                 ScalarEvolution &SE,
555                                                 const Value *BaseAddress) {
556   if (isa<SCEVCouldNotCompute>(Expr))
557     return std::vector<const SCEV *>();
558 
559   SCEVValidator Validator(R, SE, BaseAddress);
560   ValidatorResult Result = Validator.visit(Expr);
561   assert(Result.isValid() && "Requested parameters for an invalid SCEV!");
562 
563   return Result.getParameters();
564 }
565 
566 std::pair<const SCEV *, const SCEV *>
567 extractConstantFactor(const SCEV *S, ScalarEvolution &SE) {
568 
569   const SCEV *LeftOver = SE.getConstant(S->getType(), 1);
570   const SCEV *ConstPart = SE.getConstant(S->getType(), 1);
571 
572   const SCEVMulExpr *M = dyn_cast<SCEVMulExpr>(S);
573   if (!M)
574     return std::make_pair(ConstPart, S);
575 
576   for (const SCEV *Op : M->operands())
577     if (isa<SCEVConstant>(Op))
578       ConstPart = SE.getMulExpr(ConstPart, Op);
579     else
580       LeftOver = SE.getMulExpr(LeftOver, Op);
581 
582   return std::make_pair(ConstPart, LeftOver);
583 }
584 }
585