xref: /llvm-project/polly/lib/Support/SCEVValidator.cpp (revision 7ca8dc2d2df159c7f292b2e50fee7d2957986558)
1 
2 #include "polly/Support/SCEVValidator.h"
3 #include "polly/ScopInfo.h"
4 #include "llvm/Analysis/RegionInfo.h"
5 #include "llvm/Analysis/ScalarEvolution.h"
6 #include "llvm/Analysis/ScalarEvolutionExpressions.h"
7 #include "llvm/Support/Debug.h"
8 #include <vector>
9 
10 using namespace llvm;
11 
12 #define DEBUG_TYPE "polly-scev-validator"
13 
namespace SCEVType {
/// @brief The type of a SCEV
///
/// To check for the validity of a SCEV we assign a type to each SCEV. The
/// possible types are INT, PARAM, IV and INVALID. The numeric order of the
/// enumerators is significant: a subexpression of a SCEV of type X can only
/// have a type that is less than or equal to X. The values are spelled out
/// explicitly so that this ordering is not changed by accident.
enum TYPE {
  /// An integer value.
  INT = 0,

  /// An expression that is constant during the execution of the SCoP,
  /// but that may depend on parameters unknown at compile time.
  PARAM = 1,

  /// An expression that may change during the execution of the SCoP.
  IV = 2,

  /// An invalid expression.
  INVALID = 3
};
} // namespace SCEVType
36 
37 /// @brief The result the validator returns for a SCEV expression.
38 class ValidatorResult {
39   /// @brief The type of the expression
40   SCEVType::TYPE Type;
41 
42   /// @brief The set of Parameters in the expression.
43   std::vector<const SCEV *> Parameters;
44 
45 public:
46   /// @brief The copy constructor
47   ValidatorResult(const ValidatorResult &Source) {
48     Type = Source.Type;
49     Parameters = Source.Parameters;
50   }
51 
52   /// @brief Construct a result with a certain type and no parameters.
53   ValidatorResult(SCEVType::TYPE Type) : Type(Type) {
54     assert(Type != SCEVType::PARAM && "Did you forget to pass the parameter");
55   }
56 
57   /// @brief Construct a result with a certain type and a single parameter.
58   ValidatorResult(SCEVType::TYPE Type, const SCEV *Expr) : Type(Type) {
59     Parameters.push_back(Expr);
60   }
61 
62   /// @brief Get the type of the ValidatorResult.
63   SCEVType::TYPE getType() { return Type; }
64 
65   /// @brief Is the analyzed SCEV constant during the execution of the SCoP.
66   bool isConstant() { return Type == SCEVType::INT || Type == SCEVType::PARAM; }
67 
68   /// @brief Is the analyzed SCEV valid.
69   bool isValid() { return Type != SCEVType::INVALID; }
70 
71   /// @brief Is the analyzed SCEV of Type IV.
72   bool isIV() { return Type == SCEVType::IV; }
73 
74   /// @brief Is the analyzed SCEV of Type INT.
75   bool isINT() { return Type == SCEVType::INT; }
76 
77   /// @brief Is the analyzed SCEV of Type PARAM.
78   bool isPARAM() { return Type == SCEVType::PARAM; }
79 
80   /// @brief Get the parameters of this validator result.
81   std::vector<const SCEV *> getParameters() { return Parameters; }
82 
83   /// @brief Add the parameters of Source to this result.
84   void addParamsFrom(const ValidatorResult &Source) {
85     Parameters.insert(Parameters.end(), Source.Parameters.begin(),
86                       Source.Parameters.end());
87   }
88 
89   /// @brief Merge a result.
90   ///
91   /// This means to merge the parameters and to set the Type to the most
92   /// specific Type that matches both.
93   void merge(const ValidatorResult &ToMerge) {
94     Type = std::max(Type, ToMerge.Type);
95     addParamsFrom(ToMerge);
96   }
97 
98   void print(raw_ostream &OS) {
99     switch (Type) {
100     case SCEVType::INT:
101       OS << "SCEVType::INT";
102       break;
103     case SCEVType::PARAM:
104       OS << "SCEVType::PARAM";
105       break;
106     case SCEVType::IV:
107       OS << "SCEVType::IV";
108       break;
109     case SCEVType::INVALID:
110       OS << "SCEVType::INVALID";
111       break;
112     }
113   }
114 };
115 
116 raw_ostream &operator<<(raw_ostream &OS, class ValidatorResult &VR) {
117   VR.print(OS);
118   return OS;
119 }
120 
121 /// Check if a SCEV is valid in a SCoP.
122 struct SCEVValidator
123     : public SCEVVisitor<SCEVValidator, class ValidatorResult> {
124 private:
125   const Region *R;
126   ScalarEvolution &SE;
127   const Value *BaseAddress;
128 
129 public:
130   SCEVValidator(const Region *R, ScalarEvolution &SE, const Value *BaseAddress)
131       : R(R), SE(SE), BaseAddress(BaseAddress) {}
132 
133   class ValidatorResult visitConstant(const SCEVConstant *Constant) {
134     return ValidatorResult(SCEVType::INT);
135   }
136 
137   class ValidatorResult visitTruncateExpr(const SCEVTruncateExpr *Expr) {
138     ValidatorResult Op = visit(Expr->getOperand());
139 
140     switch (Op.getType()) {
141     case SCEVType::INT:
142     case SCEVType::PARAM:
143       // We currently do not represent a truncate expression as an affine
144       // expression. If it is constant during Scop execution, we treat it as a
145       // parameter.
146       return ValidatorResult(SCEVType::PARAM, Expr);
147     case SCEVType::IV:
148       DEBUG(dbgs() << "INVALID: Truncation of SCEVType::IV expression");
149       return ValidatorResult(SCEVType::INVALID);
150     case SCEVType::INVALID:
151       return Op;
152     }
153 
154     llvm_unreachable("Unknown SCEVType");
155   }
156 
157   class ValidatorResult visitZeroExtendExpr(const SCEVZeroExtendExpr *Expr) {
158     ValidatorResult Op = visit(Expr->getOperand());
159 
160     switch (Op.getType()) {
161     case SCEVType::INT:
162     case SCEVType::PARAM:
163       // We currently do not represent a truncate expression as an affine
164       // expression. If it is constant during Scop execution, we treat it as a
165       // parameter.
166       return ValidatorResult(SCEVType::PARAM, Expr);
167     case SCEVType::IV:
168       DEBUG(dbgs() << "INVALID: ZeroExtend of SCEVType::IV expression");
169       return ValidatorResult(SCEVType::INVALID);
170     case SCEVType::INVALID:
171       return Op;
172     }
173 
174     llvm_unreachable("Unknown SCEVType");
175   }
176 
177   class ValidatorResult visitSignExtendExpr(const SCEVSignExtendExpr *Expr) {
178     // We currently allow only signed SCEV expressions. In the case of a
179     // signed value, a sign extend is a noop.
180     //
181     // TODO: Reconsider this when we add support for unsigned values.
182     return visit(Expr->getOperand());
183   }
184 
185   class ValidatorResult visitAddExpr(const SCEVAddExpr *Expr) {
186     ValidatorResult Return(SCEVType::INT);
187 
188     for (int i = 0, e = Expr->getNumOperands(); i < e; ++i) {
189       ValidatorResult Op = visit(Expr->getOperand(i));
190       Return.merge(Op);
191 
192       // Early exit.
193       if (!Return.isValid())
194         break;
195     }
196 
197     // TODO: Check for NSW and NUW.
198     return Return;
199   }
200 
201   class ValidatorResult visitMulExpr(const SCEVMulExpr *Expr) {
202     ValidatorResult Return(SCEVType::INT);
203 
204     bool HasMultipleParams = false;
205 
206     for (int i = 0, e = Expr->getNumOperands(); i < e; ++i) {
207       ValidatorResult Op = visit(Expr->getOperand(i));
208 
209       if (Op.isINT())
210         continue;
211 
212       if (Op.isPARAM() && Return.isPARAM()) {
213         HasMultipleParams = true;
214         continue;
215       }
216 
217       if ((Op.isIV() || Op.isPARAM()) && !Return.isINT()) {
218         DEBUG(dbgs() << "INVALID: More than one non-int operand in MulExpr\n"
219                      << "\tExpr: " << *Expr << "\n"
220                      << "\tPrevious expression type: " << Return << "\n"
221                      << "\tNext operand (" << Op
222                      << "): " << *Expr->getOperand(i) << "\n");
223 
224         return ValidatorResult(SCEVType::INVALID);
225       }
226 
227       Return.merge(Op);
228     }
229 
230     if (HasMultipleParams && Return.isValid())
231       return ValidatorResult(SCEVType::PARAM, Expr);
232 
233     // TODO: Check for NSW and NUW.
234     return Return;
235   }
236 
237   class ValidatorResult visitUDivExpr(const SCEVUDivExpr *Expr) {
238     ValidatorResult LHS = visit(Expr->getLHS());
239     ValidatorResult RHS = visit(Expr->getRHS());
240 
241     // We currently do not represent an unsigned division as an affine
242     // expression. If the division is constant during Scop execution we treat it
243     // as a parameter, otherwise we bail out.
244     if (LHS.isConstant() && RHS.isConstant())
245       return ValidatorResult(SCEVType::PARAM, Expr);
246 
247     DEBUG(dbgs() << "INVALID: unsigned division of non-constant expressions");
248     return ValidatorResult(SCEVType::INVALID);
249   }
250 
251   class ValidatorResult visitAddRecExpr(const SCEVAddRecExpr *Expr) {
252     if (!Expr->isAffine()) {
253       DEBUG(dbgs() << "INVALID: AddRec is not affine");
254       return ValidatorResult(SCEVType::INVALID);
255     }
256 
257     ValidatorResult Start = visit(Expr->getStart());
258     ValidatorResult Recurrence = visit(Expr->getStepRecurrence(SE));
259 
260     if (!Start.isValid())
261       return Start;
262 
263     if (!Recurrence.isValid())
264       return Recurrence;
265 
266     if (R->contains(Expr->getLoop())) {
267       if (Recurrence.isINT()) {
268         ValidatorResult Result(SCEVType::IV);
269         Result.addParamsFrom(Start);
270         return Result;
271       }
272 
273       DEBUG(dbgs() << "INVALID: AddRec within scop has non-int"
274                       "recurrence part");
275       return ValidatorResult(SCEVType::INVALID);
276     }
277 
278     assert(Start.isConstant() && Recurrence.isConstant() &&
279            "Expected 'Start' and 'Recurrence' to be constant");
280 
281     // Directly generate ValidatorResult for Expr if 'start' is zero.
282     if (Expr->getStart()->isZero())
283       return ValidatorResult(SCEVType::PARAM, Expr);
284 
285     // Translate AddRecExpr from '{start, +, inc}' into 'start + {0, +, inc}'
286     // if 'start' is not zero.
287     const SCEV *ZeroStartExpr = SE.getAddRecExpr(
288         SE.getConstant(Expr->getStart()->getType(), 0),
289         Expr->getStepRecurrence(SE), Expr->getLoop(), Expr->getNoWrapFlags());
290 
291     ValidatorResult ZeroStartResult =
292         ValidatorResult(SCEVType::PARAM, ZeroStartExpr);
293     ZeroStartResult.addParamsFrom(Start);
294 
295     return ZeroStartResult;
296   }
297 
298   class ValidatorResult visitSMaxExpr(const SCEVSMaxExpr *Expr) {
299     ValidatorResult Return(SCEVType::INT);
300 
301     for (int i = 0, e = Expr->getNumOperands(); i < e; ++i) {
302       ValidatorResult Op = visit(Expr->getOperand(i));
303 
304       if (!Op.isValid())
305         return Op;
306 
307       Return.merge(Op);
308     }
309 
310     return Return;
311   }
312 
313   class ValidatorResult visitUMaxExpr(const SCEVUMaxExpr *Expr) {
314     // We do not support unsigned operations. If 'Expr' is constant during Scop
315     // execution we treat this as a parameter, otherwise we bail out.
316     for (int i = 0, e = Expr->getNumOperands(); i < e; ++i) {
317       ValidatorResult Op = visit(Expr->getOperand(i));
318 
319       if (!Op.isConstant()) {
320         DEBUG(dbgs() << "INVALID: UMaxExpr has a non-constant operand");
321         return ValidatorResult(SCEVType::INVALID);
322       }
323     }
324 
325     return ValidatorResult(SCEVType::PARAM, Expr);
326   }
327 
328   ValidatorResult visitGenericInst(Instruction *I, const SCEV *S) {
329     if (R->contains(I)) {
330       DEBUG(dbgs() << "INVALID: UnknownExpr references an instruction "
331                       "within the region\n");
332       return ValidatorResult(SCEVType::INVALID);
333     }
334 
335     return ValidatorResult(SCEVType::PARAM, S);
336   }
337 
338   ValidatorResult visitSDivInstruction(Instruction *SDiv, const SCEV *S) {
339     assert(SDiv->getOpcode() == Instruction::SDiv &&
340            "Assumed SDiv instruction!");
341 
342     auto *Divisor = SDiv->getOperand(1);
343     auto *CI = dyn_cast<ConstantInt>(Divisor);
344     if (!CI)
345       return visitGenericInst(SDiv, S);
346 
347     auto *Dividend = SDiv->getOperand(0);
348     auto *DividendSCEV = SE.getSCEV(Dividend);
349     return visit(DividendSCEV);
350   }
351 
352   ValidatorResult visitSRemInstruction(Instruction *SRem, const SCEV *S) {
353     assert(SRem->getOpcode() == Instruction::SRem &&
354            "Assumed SRem instruction!");
355 
356     auto *Divisor = SRem->getOperand(1);
357     auto *CI = dyn_cast<ConstantInt>(Divisor);
358     if (!CI)
359       return visitGenericInst(SRem, S);
360 
361     auto *Dividend = SRem->getOperand(0);
362     auto *DividendSCEV = SE.getSCEV(Dividend);
363     return visit(DividendSCEV);
364   }
365 
366   ValidatorResult visitUnknown(const SCEVUnknown *Expr) {
367     Value *V = Expr->getValue();
368 
369     // TODO: FIXME: IslExprBuilder is not capable of producing valid code
370     //              for arbitrary pointer expressions at the moment. Until
371     //              this is fixed we disallow pointer expressions completely.
372     if (Expr->getType()->isPointerTy()) {
373       DEBUG(dbgs() << "INVALID: UnknownExpr is a pointer type [FIXME]");
374       return ValidatorResult(SCEVType::INVALID);
375     }
376 
377     if (!Expr->getType()->isIntegerTy()) {
378       DEBUG(dbgs() << "INVALID: UnknownExpr is not an integer");
379       return ValidatorResult(SCEVType::INVALID);
380     }
381 
382     if (isa<UndefValue>(V)) {
383       DEBUG(dbgs() << "INVALID: UnknownExpr references an undef value");
384       return ValidatorResult(SCEVType::INVALID);
385     }
386 
387     if (BaseAddress == V) {
388       DEBUG(dbgs() << "INVALID: UnknownExpr references BaseAddress\n");
389       return ValidatorResult(SCEVType::INVALID);
390     }
391 
392     if (Instruction *I = dyn_cast<Instruction>(Expr->getValue())) {
393       switch (I->getOpcode()) {
394       case Instruction::SDiv:
395         return visitSDivInstruction(I, Expr);
396       case Instruction::SRem:
397         return visitSRemInstruction(I, Expr);
398       default:
399         return visitGenericInst(I, Expr);
400       }
401     }
402 
403     return ValidatorResult(SCEVType::PARAM, Expr);
404   }
405 };
406 
407 /// @brief Check whether a SCEV refers to an SSA name defined inside a region.
408 ///
409 struct SCEVInRegionDependences
410     : public SCEVVisitor<SCEVInRegionDependences, bool> {
411 public:
412   /// Returns true when the SCEV has SSA names defined in region R.
413   static bool hasDependences(const SCEV *S, const Region *R) {
414     SCEVInRegionDependences Ignore(R);
415     return Ignore.visit(S);
416   }
417 
418   SCEVInRegionDependences(const Region *R) : R(R) {}
419 
420   bool visit(const SCEV *Expr) {
421     return SCEVVisitor<SCEVInRegionDependences, bool>::visit(Expr);
422   }
423 
424   bool visitConstant(const SCEVConstant *Constant) { return false; }
425 
426   bool visitTruncateExpr(const SCEVTruncateExpr *Expr) {
427     return visit(Expr->getOperand());
428   }
429 
430   bool visitZeroExtendExpr(const SCEVZeroExtendExpr *Expr) {
431     return visit(Expr->getOperand());
432   }
433 
434   bool visitSignExtendExpr(const SCEVSignExtendExpr *Expr) {
435     return visit(Expr->getOperand());
436   }
437 
438   bool visitAddExpr(const SCEVAddExpr *Expr) {
439     for (int i = 0, e = Expr->getNumOperands(); i < e; ++i)
440       if (visit(Expr->getOperand(i)))
441         return true;
442 
443     return false;
444   }
445 
446   bool visitMulExpr(const SCEVMulExpr *Expr) {
447     for (int i = 0, e = Expr->getNumOperands(); i < e; ++i)
448       if (visit(Expr->getOperand(i)))
449         return true;
450 
451     return false;
452   }
453 
454   bool visitUDivExpr(const SCEVUDivExpr *Expr) {
455     if (visit(Expr->getLHS()))
456       return true;
457 
458     if (visit(Expr->getRHS()))
459       return true;
460 
461     return false;
462   }
463 
464   bool visitAddRecExpr(const SCEVAddRecExpr *Expr) {
465     if (visit(Expr->getStart()))
466       return true;
467 
468     for (size_t i = 0; i < Expr->getNumOperands(); ++i)
469       if (visit(Expr->getOperand(i)))
470         return true;
471 
472     return false;
473   }
474 
475   bool visitSMaxExpr(const SCEVSMaxExpr *Expr) {
476     for (size_t i = 0; i < Expr->getNumOperands(); ++i)
477       if (visit(Expr->getOperand(i)))
478         return true;
479 
480     return false;
481   }
482 
483   bool visitUMaxExpr(const SCEVUMaxExpr *Expr) {
484     for (size_t i = 0; i < Expr->getNumOperands(); ++i)
485       if (visit(Expr->getOperand(i)))
486         return true;
487 
488     return false;
489   }
490 
491   bool visitUnknown(const SCEVUnknown *Expr) {
492     Instruction *Inst = dyn_cast<Instruction>(Expr->getValue());
493 
494     // Return true when Inst is defined inside the region R.
495     if (Inst && R->contains(Inst))
496       return true;
497 
498     return false;
499   }
500 
501 private:
502   const Region *R;
503 };
504 
505 namespace polly {
506 /// Find all loops referenced in SCEVAddRecExprs.
507 class SCEVFindLoops {
508   SetVector<const Loop *> &Loops;
509 
510 public:
511   SCEVFindLoops(SetVector<const Loop *> &Loops) : Loops(Loops) {}
512 
513   bool follow(const SCEV *S) {
514     if (const SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(S))
515       Loops.insert(AddRec->getLoop());
516     return true;
517   }
518   bool isDone() { return false; }
519 };
520 
521 void findLoops(const SCEV *Expr, SetVector<const Loop *> &Loops) {
522   SCEVFindLoops FindLoops(Loops);
523   SCEVTraversal<SCEVFindLoops> ST(FindLoops);
524   ST.visitAll(Expr);
525 }
526 
527 /// Find all values referenced in SCEVUnknowns.
528 class SCEVFindValues {
529   SetVector<Value *> &Values;
530 
531 public:
532   SCEVFindValues(SetVector<Value *> &Values) : Values(Values) {}
533 
534   bool follow(const SCEV *S) {
535     if (const SCEVUnknown *Unknown = dyn_cast<SCEVUnknown>(S))
536       Values.insert(Unknown->getValue());
537     return true;
538   }
539   bool isDone() { return false; }
540 };
541 
542 void findValues(const SCEV *Expr, SetVector<Value *> &Values) {
543   SCEVFindValues FindValues(Values);
544   SCEVTraversal<SCEVFindValues> ST(FindValues);
545   ST.visitAll(Expr);
546 }
547 
548 bool hasScalarDepsInsideRegion(const SCEV *Expr, const Region *R) {
549   return SCEVInRegionDependences::hasDependences(Expr, R);
550 }
551 
552 bool isAffineExpr(const Region *R, const SCEV *Expr, ScalarEvolution &SE,
553                   const Value *BaseAddress) {
554   if (isa<SCEVCouldNotCompute>(Expr))
555     return false;
556 
557   SCEVValidator Validator(R, SE, BaseAddress);
558   DEBUG({
559     dbgs() << "\n";
560     dbgs() << "Expr: " << *Expr << "\n";
561     dbgs() << "Region: " << R->getNameStr() << "\n";
562     dbgs() << " -> ";
563   });
564 
565   ValidatorResult Result = Validator.visit(Expr);
566 
567   DEBUG({
568     if (Result.isValid())
569       dbgs() << "VALID\n";
570     dbgs() << "\n";
571   });
572 
573   return Result.isValid();
574 }
575 
576 std::vector<const SCEV *> getParamsInAffineExpr(const Region *R,
577                                                 const SCEV *Expr,
578                                                 ScalarEvolution &SE,
579                                                 const Value *BaseAddress) {
580   if (isa<SCEVCouldNotCompute>(Expr))
581     return std::vector<const SCEV *>();
582 
583   SCEVValidator Validator(R, SE, BaseAddress);
584   ValidatorResult Result = Validator.visit(Expr);
585   assert(Result.isValid() && "Requested parameters for an invalid SCEV!");
586 
587   return Result.getParameters();
588 }
589 
590 std::pair<const SCEV *, const SCEV *>
591 extractConstantFactor(const SCEV *S, ScalarEvolution &SE) {
592 
593   const SCEV *LeftOver = SE.getConstant(S->getType(), 1);
594   const SCEV *ConstPart = SE.getConstant(S->getType(), 1);
595 
596   const SCEVMulExpr *M = dyn_cast<SCEVMulExpr>(S);
597   if (!M)
598     return std::make_pair(ConstPart, S);
599 
600   for (const SCEV *Op : M->operands())
601     if (isa<SCEVConstant>(Op))
602       ConstPart = SE.getMulExpr(ConstPart, Op);
603     else
604       LeftOver = SE.getMulExpr(LeftOver, Op);
605 
606   return std::make_pair(ConstPart, LeftOver);
607 }
608 }
609