xref: /llvm-project/polly/lib/Support/SCEVValidator.cpp (revision d5d8f67dc5ec18bfd172dc620a49f68db898512d)
1 
2 #include "polly/Support/SCEVValidator.h"
3 #include "polly/ScopInfo.h"
4 #include "llvm/Analysis/ScalarEvolution.h"
5 #include "llvm/Analysis/ScalarEvolutionExpressions.h"
6 #include "llvm/Analysis/RegionInfo.h"
7 #include "llvm/Support/Debug.h"
8 
9 #include <vector>
10 
11 using namespace llvm;
12 
13 #define DEBUG_TYPE "polly-scev-validator"
14 
namespace SCEVType {
/// @brief Classification the validator assigns to every SCEV.
///
/// Each SCEV is mapped onto exactly one of INT, PARAM, IV or INVALID. The
/// enumerators are declared in increasing order of generality, and that order
/// is significant: a SCEV classified as X may only have subexpressions whose
/// classification compares less than or equal to X. Combining two
/// classifications therefore amounts to taking the larger enumerator.
enum TYPE {
  INT,    ///< An integer value.
  PARAM,  ///< Constant while the Scop executes, but possibly depending on
          ///< parameters that are unknown at compile time.
  IV,     ///< May change while the SCoP executes.
  INVALID ///< An invalid expression.
};
} // namespace SCEVType
37 
38 /// @brief The result the validator returns for a SCEV expression.
39 class ValidatorResult {
40   /// @brief The type of the expression
41   SCEVType::TYPE Type;
42 
43   /// @brief The set of Parameters in the expression.
44   std::vector<const SCEV *> Parameters;
45 
46 public:
47   /// @brief The copy constructor
48   ValidatorResult(const ValidatorResult &Source) {
49     Type = Source.Type;
50     Parameters = Source.Parameters;
51   }
52 
53   /// @brief Construct a result with a certain type and no parameters.
54   ValidatorResult(SCEVType::TYPE Type) : Type(Type) {
55     assert(Type != SCEVType::PARAM && "Did you forget to pass the parameter");
56   }
57 
58   /// @brief Construct a result with a certain type and a single parameter.
59   ValidatorResult(SCEVType::TYPE Type, const SCEV *Expr) : Type(Type) {
60     Parameters.push_back(Expr);
61   }
62 
63   /// @brief Get the type of the ValidatorResult.
64   SCEVType::TYPE getType() { return Type; }
65 
66   /// @brief Is the analyzed SCEV constant during the execution of the SCoP.
67   bool isConstant() { return Type == SCEVType::INT || Type == SCEVType::PARAM; }
68 
69   /// @brief Is the analyzed SCEV valid.
70   bool isValid() { return Type != SCEVType::INVALID; }
71 
72   /// @brief Is the analyzed SCEV of Type IV.
73   bool isIV() { return Type == SCEVType::IV; }
74 
75   /// @brief Is the analyzed SCEV of Type INT.
76   bool isINT() { return Type == SCEVType::INT; }
77 
78   /// @brief Is the analyzed SCEV of Type PARAM.
79   bool isPARAM() { return Type == SCEVType::PARAM; }
80 
81   /// @brief Get the parameters of this validator result.
82   std::vector<const SCEV *> getParameters() { return Parameters; }
83 
84   /// @brief Add the parameters of Source to this result.
85   void addParamsFrom(const ValidatorResult &Source) {
86     Parameters.insert(Parameters.end(), Source.Parameters.begin(),
87                       Source.Parameters.end());
88   }
89 
90   /// @brief Merge a result.
91   ///
92   /// This means to merge the parameters and to set the Type to the most
93   /// specific Type that matches both.
94   void merge(const ValidatorResult &ToMerge) {
95     Type = std::max(Type, ToMerge.Type);
96     addParamsFrom(ToMerge);
97   }
98 
99   void print(raw_ostream &OS) {
100     switch (Type) {
101     case SCEVType::INT:
102       OS << "SCEVType::INT";
103       break;
104     case SCEVType::PARAM:
105       OS << "SCEVType::PARAM";
106       break;
107     case SCEVType::IV:
108       OS << "SCEVType::IV";
109       break;
110     case SCEVType::INVALID:
111       OS << "SCEVType::INVALID";
112       break;
113     }
114   }
115 };
116 
117 raw_ostream &operator<<(raw_ostream &OS, class ValidatorResult &VR) {
118   VR.print(OS);
119   return OS;
120 }
121 
122 /// Check if a SCEV is valid in a SCoP.
123 struct SCEVValidator
124     : public SCEVVisitor<SCEVValidator, class ValidatorResult> {
125 private:
126   const Region *R;
127   ScalarEvolution &SE;
128   const Value *BaseAddress;
129 
130 public:
131   SCEVValidator(const Region *R, ScalarEvolution &SE, const Value *BaseAddress)
132       : R(R), SE(SE), BaseAddress(BaseAddress) {}
133 
134   class ValidatorResult visitConstant(const SCEVConstant *Constant) {
135     return ValidatorResult(SCEVType::INT);
136   }
137 
138   class ValidatorResult visitTruncateExpr(const SCEVTruncateExpr *Expr) {
139     ValidatorResult Op = visit(Expr->getOperand());
140 
141     switch (Op.getType()) {
142     case SCEVType::INT:
143     case SCEVType::PARAM:
144       // We currently do not represent a truncate expression as an affine
145       // expression. If it is constant during Scop execution, we treat it as a
146       // parameter.
147       return ValidatorResult(SCEVType::PARAM, Expr);
148     case SCEVType::IV:
149       DEBUG(dbgs() << "INVALID: Truncation of SCEVType::IV expression");
150       return ValidatorResult(SCEVType::INVALID);
151     case SCEVType::INVALID:
152       return Op;
153     }
154 
155     llvm_unreachable("Unknown SCEVType");
156   }
157 
158   class ValidatorResult visitZeroExtendExpr(const SCEVZeroExtendExpr *Expr) {
159     ValidatorResult Op = visit(Expr->getOperand());
160 
161     switch (Op.getType()) {
162     case SCEVType::INT:
163     case SCEVType::PARAM:
164       // We currently do not represent a truncate expression as an affine
165       // expression. If it is constant during Scop execution, we treat it as a
166       // parameter.
167       return ValidatorResult(SCEVType::PARAM, Expr);
168     case SCEVType::IV:
169       DEBUG(dbgs() << "INVALID: ZeroExtend of SCEVType::IV expression");
170       return ValidatorResult(SCEVType::INVALID);
171     case SCEVType::INVALID:
172       return Op;
173     }
174 
175     llvm_unreachable("Unknown SCEVType");
176   }
177 
178   class ValidatorResult visitSignExtendExpr(const SCEVSignExtendExpr *Expr) {
179     // We currently allow only signed SCEV expressions. In the case of a
180     // signed value, a sign extend is a noop.
181     //
182     // TODO: Reconsider this when we add support for unsigned values.
183     return visit(Expr->getOperand());
184   }
185 
186   class ValidatorResult visitAddExpr(const SCEVAddExpr *Expr) {
187     ValidatorResult Return(SCEVType::INT);
188 
189     for (int i = 0, e = Expr->getNumOperands(); i < e; ++i) {
190       ValidatorResult Op = visit(Expr->getOperand(i));
191       Return.merge(Op);
192 
193       // Early exit.
194       if (!Return.isValid())
195         break;
196     }
197 
198     // TODO: Check for NSW and NUW.
199     return Return;
200   }
201 
202   class ValidatorResult visitMulExpr(const SCEVMulExpr *Expr) {
203     ValidatorResult Return(SCEVType::INT);
204 
205     bool HasMultipleParams = false;
206 
207     for (int i = 0, e = Expr->getNumOperands(); i < e; ++i) {
208       ValidatorResult Op = visit(Expr->getOperand(i));
209 
210       if (Op.isINT())
211         continue;
212 
213       if (Op.isPARAM() && Return.isPARAM()) {
214         HasMultipleParams = true;
215         continue;
216       }
217 
218       if ((Op.isIV() || Op.isPARAM()) && !Return.isINT()) {
219         DEBUG(dbgs() << "INVALID: More than one non-int operand in MulExpr\n"
220                      << "\tExpr: " << *Expr << "\n"
221                      << "\tPrevious expression type: " << Return << "\n"
222                      << "\tNext operand (" << Op
223                      << "): " << *Expr->getOperand(i) << "\n");
224 
225         return ValidatorResult(SCEVType::INVALID);
226       }
227 
228       Return.merge(Op);
229     }
230 
231     if (HasMultipleParams && Return.isValid())
232       return ValidatorResult(SCEVType::PARAM, Expr);
233 
234     // TODO: Check for NSW and NUW.
235     return Return;
236   }
237 
238   class ValidatorResult visitUDivExpr(const SCEVUDivExpr *Expr) {
239     ValidatorResult LHS = visit(Expr->getLHS());
240     ValidatorResult RHS = visit(Expr->getRHS());
241 
242     // We currently do not represent an unsigned division as an affine
243     // expression. If the division is constant during Scop execution we treat it
244     // as a parameter, otherwise we bail out.
245     if (LHS.isConstant() && RHS.isConstant())
246       return ValidatorResult(SCEVType::PARAM, Expr);
247 
248     DEBUG(dbgs() << "INVALID: unsigned division of non-constant expressions");
249     return ValidatorResult(SCEVType::INVALID);
250   }
251 
252   class ValidatorResult visitAddRecExpr(const SCEVAddRecExpr *Expr) {
253     if (!Expr->isAffine()) {
254       DEBUG(dbgs() << "INVALID: AddRec is not affine");
255       return ValidatorResult(SCEVType::INVALID);
256     }
257 
258     ValidatorResult Start = visit(Expr->getStart());
259     ValidatorResult Recurrence = visit(Expr->getStepRecurrence(SE));
260 
261     if (!Start.isValid())
262       return Start;
263 
264     if (!Recurrence.isValid())
265       return Recurrence;
266 
267     if (R->contains(Expr->getLoop())) {
268       if (Recurrence.isINT()) {
269         ValidatorResult Result(SCEVType::IV);
270         Result.addParamsFrom(Start);
271         return Result;
272       }
273 
274       DEBUG(dbgs() << "INVALID: AddRec within scop has non-int"
275                       "recurrence part");
276       return ValidatorResult(SCEVType::INVALID);
277     }
278 
279     assert(Start.isConstant() && Recurrence.isConstant() &&
280            "Expected 'Start' and 'Recurrence' to be constant");
281 
282     // Directly generate ValidatorResult for Expr if 'start' is zero.
283     if (Expr->getStart()->isZero())
284       return ValidatorResult(SCEVType::PARAM, Expr);
285 
286     // Translate AddRecExpr from '{start, +, inc}' into 'start + {0, +, inc}'
287     // if 'start' is not zero.
288     const SCEV *ZeroStartExpr = SE.getAddRecExpr(
289         SE.getConstant(Expr->getStart()->getType(), 0),
290         Expr->getStepRecurrence(SE), Expr->getLoop(), Expr->getNoWrapFlags());
291 
292     ValidatorResult ZeroStartResult =
293         ValidatorResult(SCEVType::PARAM, ZeroStartExpr);
294     ZeroStartResult.addParamsFrom(Start);
295 
296     return ZeroStartResult;
297   }
298 
299   class ValidatorResult visitSMaxExpr(const SCEVSMaxExpr *Expr) {
300     ValidatorResult Return(SCEVType::INT);
301 
302     for (int i = 0, e = Expr->getNumOperands(); i < e; ++i) {
303       ValidatorResult Op = visit(Expr->getOperand(i));
304 
305       if (!Op.isValid())
306         return Op;
307 
308       Return.merge(Op);
309     }
310 
311     return Return;
312   }
313 
314   class ValidatorResult visitUMaxExpr(const SCEVUMaxExpr *Expr) {
315     // We do not support unsigned operations. If 'Expr' is constant during Scop
316     // execution we treat this as a parameter, otherwise we bail out.
317     for (int i = 0, e = Expr->getNumOperands(); i < e; ++i) {
318       ValidatorResult Op = visit(Expr->getOperand(i));
319 
320       if (!Op.isConstant()) {
321         DEBUG(dbgs() << "INVALID: UMaxExpr has a non-constant operand");
322         return ValidatorResult(SCEVType::INVALID);
323       }
324     }
325 
326     return ValidatorResult(SCEVType::PARAM, Expr);
327   }
328 
329   ValidatorResult visitGenericInst(Instruction *I, const SCEV *S) {
330     if (R->contains(I)) {
331       DEBUG(dbgs() << "INVALID: UnknownExpr references an instruction "
332                       "within the region\n");
333       return ValidatorResult(SCEVType::INVALID);
334     }
335 
336     return ValidatorResult(SCEVType::PARAM, S);
337   }
338 
339   ValidatorResult visitSDivInstruction(Instruction *SDiv, const SCEV *S) {
340     assert(SDiv->getOpcode() == Instruction::SDiv &&
341            "Assumed SDiv instruction!");
342 
343     auto *Divisor = SDiv->getOperand(1);
344     auto *CI = dyn_cast<ConstantInt>(Divisor);
345     if (!CI)
346       return visitGenericInst(SDiv, S);
347 
348     auto *Dividend = SDiv->getOperand(0);
349     auto *DividendSCEV = SE.getSCEV(Dividend);
350     return visit(DividendSCEV);
351   }
352 
353   ValidatorResult visitUnknown(const SCEVUnknown *Expr) {
354     Value *V = Expr->getValue();
355 
356     if (!(Expr->getType()->isIntegerTy() || Expr->getType()->isPointerTy())) {
357       DEBUG(dbgs() << "INVALID: UnknownExpr is not an integer or pointer type");
358       return ValidatorResult(SCEVType::INVALID);
359     }
360 
361     if (isa<UndefValue>(V)) {
362       DEBUG(dbgs() << "INVALID: UnknownExpr references an undef value");
363       return ValidatorResult(SCEVType::INVALID);
364     }
365 
366     if (BaseAddress == V) {
367       DEBUG(dbgs() << "INVALID: UnknownExpr references BaseAddress\n");
368       return ValidatorResult(SCEVType::INVALID);
369     }
370 
371     if (Instruction *I = dyn_cast<Instruction>(Expr->getValue())) {
372       switch (I->getOpcode()) {
373       case Instruction::SDiv:
374         return visitSDivInstruction(I, Expr);
375       default:
376         return visitGenericInst(I, Expr);
377       }
378     }
379 
380     return ValidatorResult(SCEVType::PARAM, Expr);
381   }
382 };
383 
384 /// @brief Check whether a SCEV refers to an SSA name defined inside a region.
385 ///
386 struct SCEVInRegionDependences
387     : public SCEVVisitor<SCEVInRegionDependences, bool> {
388 public:
389   /// Returns true when the SCEV has SSA names defined in region R.
390   static bool hasDependences(const SCEV *S, const Region *R) {
391     SCEVInRegionDependences Ignore(R);
392     return Ignore.visit(S);
393   }
394 
395   SCEVInRegionDependences(const Region *R) : R(R) {}
396 
397   bool visit(const SCEV *Expr) {
398     return SCEVVisitor<SCEVInRegionDependences, bool>::visit(Expr);
399   }
400 
401   bool visitConstant(const SCEVConstant *Constant) { return false; }
402 
403   bool visitTruncateExpr(const SCEVTruncateExpr *Expr) {
404     return visit(Expr->getOperand());
405   }
406 
407   bool visitZeroExtendExpr(const SCEVZeroExtendExpr *Expr) {
408     return visit(Expr->getOperand());
409   }
410 
411   bool visitSignExtendExpr(const SCEVSignExtendExpr *Expr) {
412     return visit(Expr->getOperand());
413   }
414 
415   bool visitAddExpr(const SCEVAddExpr *Expr) {
416     for (int i = 0, e = Expr->getNumOperands(); i < e; ++i)
417       if (visit(Expr->getOperand(i)))
418         return true;
419 
420     return false;
421   }
422 
423   bool visitMulExpr(const SCEVMulExpr *Expr) {
424     for (int i = 0, e = Expr->getNumOperands(); i < e; ++i)
425       if (visit(Expr->getOperand(i)))
426         return true;
427 
428     return false;
429   }
430 
431   bool visitUDivExpr(const SCEVUDivExpr *Expr) {
432     if (visit(Expr->getLHS()))
433       return true;
434 
435     if (visit(Expr->getRHS()))
436       return true;
437 
438     return false;
439   }
440 
441   bool visitAddRecExpr(const SCEVAddRecExpr *Expr) {
442     if (visit(Expr->getStart()))
443       return true;
444 
445     for (size_t i = 0; i < Expr->getNumOperands(); ++i)
446       if (visit(Expr->getOperand(i)))
447         return true;
448 
449     return false;
450   }
451 
452   bool visitSMaxExpr(const SCEVSMaxExpr *Expr) {
453     for (size_t i = 0; i < Expr->getNumOperands(); ++i)
454       if (visit(Expr->getOperand(i)))
455         return true;
456 
457     return false;
458   }
459 
460   bool visitUMaxExpr(const SCEVUMaxExpr *Expr) {
461     for (size_t i = 0; i < Expr->getNumOperands(); ++i)
462       if (visit(Expr->getOperand(i)))
463         return true;
464 
465     return false;
466   }
467 
468   bool visitUnknown(const SCEVUnknown *Expr) {
469     Instruction *Inst = dyn_cast<Instruction>(Expr->getValue());
470 
471     // Return true when Inst is defined inside the region R.
472     if (Inst && R->contains(Inst))
473       return true;
474 
475     return false;
476   }
477 
478 private:
479   const Region *R;
480 };
481 
482 namespace polly {
483 /// Find all loops referenced in SCEVAddRecExprs.
484 class SCEVFindLoops {
485   SetVector<const Loop *> &Loops;
486 
487 public:
488   SCEVFindLoops(SetVector<const Loop *> &Loops) : Loops(Loops) {}
489 
490   bool follow(const SCEV *S) {
491     if (const SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(S))
492       Loops.insert(AddRec->getLoop());
493     return true;
494   }
495   bool isDone() { return false; }
496 };
497 
498 void findLoops(const SCEV *Expr, SetVector<const Loop *> &Loops) {
499   SCEVFindLoops FindLoops(Loops);
500   SCEVTraversal<SCEVFindLoops> ST(FindLoops);
501   ST.visitAll(Expr);
502 }
503 
504 /// Find all values referenced in SCEVUnknowns.
505 class SCEVFindValues {
506   SetVector<Value *> &Values;
507 
508 public:
509   SCEVFindValues(SetVector<Value *> &Values) : Values(Values) {}
510 
511   bool follow(const SCEV *S) {
512     if (const SCEVUnknown *Unknown = dyn_cast<SCEVUnknown>(S))
513       Values.insert(Unknown->getValue());
514     return true;
515   }
516   bool isDone() { return false; }
517 };
518 
519 void findValues(const SCEV *Expr, SetVector<Value *> &Values) {
520   SCEVFindValues FindValues(Values);
521   SCEVTraversal<SCEVFindValues> ST(FindValues);
522   ST.visitAll(Expr);
523 }
524 
525 bool hasScalarDepsInsideRegion(const SCEV *Expr, const Region *R) {
526   return SCEVInRegionDependences::hasDependences(Expr, R);
527 }
528 
529 bool isAffineExpr(const Region *R, const SCEV *Expr, ScalarEvolution &SE,
530                   const Value *BaseAddress) {
531   if (isa<SCEVCouldNotCompute>(Expr))
532     return false;
533 
534   SCEVValidator Validator(R, SE, BaseAddress);
535   DEBUG({
536     dbgs() << "\n";
537     dbgs() << "Expr: " << *Expr << "\n";
538     dbgs() << "Region: " << R->getNameStr() << "\n";
539     dbgs() << " -> ";
540   });
541 
542   ValidatorResult Result = Validator.visit(Expr);
543 
544   DEBUG({
545     if (Result.isValid())
546       dbgs() << "VALID\n";
547     dbgs() << "\n";
548   });
549 
550   return Result.isValid();
551 }
552 
553 std::vector<const SCEV *> getParamsInAffineExpr(const Region *R,
554                                                 const SCEV *Expr,
555                                                 ScalarEvolution &SE,
556                                                 const Value *BaseAddress) {
557   if (isa<SCEVCouldNotCompute>(Expr))
558     return std::vector<const SCEV *>();
559 
560   SCEVValidator Validator(R, SE, BaseAddress);
561   ValidatorResult Result = Validator.visit(Expr);
562 
563   return Result.getParameters();
564 }
565 
566 std::pair<const SCEV *, const SCEV *>
567 extractConstantFactor(const SCEV *S, ScalarEvolution &SE) {
568 
569   const SCEV *LeftOver = SE.getConstant(S->getType(), 1);
570   const SCEV *ConstPart = SE.getConstant(S->getType(), 1);
571 
572   const SCEVMulExpr *M = dyn_cast<SCEVMulExpr>(S);
573   if (!M)
574     return std::make_pair(ConstPart, S);
575 
576   for (const SCEV *Op : M->operands())
577     if (isa<SCEVConstant>(Op))
578       ConstPart = SE.getMulExpr(ConstPart, Op);
579     else
580       LeftOver = SE.getMulExpr(LeftOver, Op);
581 
582   return std::make_pair(ConstPart, LeftOver);
583 }
584 }
585