xref: /llvm-project/polly/lib/Support/SCEVValidator.cpp (revision 5130c849aa3ac569887503a63bb790418bffea40)
1 
2 #include "polly/Support/SCEVValidator.h"
3 #include "polly/ScopInfo.h"
4 #include "llvm/Analysis/ScalarEvolution.h"
5 #include "llvm/Analysis/ScalarEvolutionExpressions.h"
6 #include "llvm/Analysis/RegionInfo.h"
7 #include "llvm/Support/Debug.h"
8 
9 #include <vector>
10 
11 using namespace llvm;
12 
13 #define DEBUG_TYPE "polly-scev-validator"
14 
namespace SCEVType {
/// @brief The classification the SCEV validator assigns to an expression.
///
/// The numeric order of the enumerators is significant: results are combined
/// by taking the larger type (see ValidatorResult::merge), and a subexpression
/// of a SCEV classified as X may only be classified as a type less than or
/// equal to X.
enum TYPE {
  INT = 0,    ///< An integer value.
  PARAM = 1,  ///< Constant during the execution of the SCoP, but possibly
              ///< dependent on parameters unknown at compile time.
  IV = 2,     ///< May change during the execution of the SCoP.
  INVALID = 3 ///< An invalid expression.
};
} // namespace SCEVType
37 
38 /// @brief The result the validator returns for a SCEV expression.
39 class ValidatorResult {
40   /// @brief The type of the expression
41   SCEVType::TYPE Type;
42 
43   /// @brief The set of Parameters in the expression.
44   std::vector<const SCEV *> Parameters;
45 
46 public:
47   /// @brief The copy constructor
48   ValidatorResult(const ValidatorResult &Source) {
49     Type = Source.Type;
50     Parameters = Source.Parameters;
51   }
52 
53   /// @brief Construct a result with a certain type and no parameters.
54   ValidatorResult(SCEVType::TYPE Type) : Type(Type) {
55     assert(Type != SCEVType::PARAM && "Did you forget to pass the parameter");
56   }
57 
58   /// @brief Construct a result with a certain type and a single parameter.
59   ValidatorResult(SCEVType::TYPE Type, const SCEV *Expr) : Type(Type) {
60     Parameters.push_back(Expr);
61   }
62 
63   /// @brief Get the type of the ValidatorResult.
64   SCEVType::TYPE getType() { return Type; }
65 
66   /// @brief Is the analyzed SCEV constant during the execution of the SCoP.
67   bool isConstant() { return Type == SCEVType::INT || Type == SCEVType::PARAM; }
68 
69   /// @brief Is the analyzed SCEV valid.
70   bool isValid() { return Type != SCEVType::INVALID; }
71 
72   /// @brief Is the analyzed SCEV of Type IV.
73   bool isIV() { return Type == SCEVType::IV; }
74 
75   /// @brief Is the analyzed SCEV of Type INT.
76   bool isINT() { return Type == SCEVType::INT; }
77 
78   /// @brief Is the analyzed SCEV of Type PARAM.
79   bool isPARAM() { return Type == SCEVType::PARAM; }
80 
81   /// @brief Get the parameters of this validator result.
82   std::vector<const SCEV *> getParameters() { return Parameters; }
83 
84   /// @brief Add the parameters of Source to this result.
85   void addParamsFrom(const ValidatorResult &Source) {
86     Parameters.insert(Parameters.end(), Source.Parameters.begin(),
87                       Source.Parameters.end());
88   }
89 
90   /// @brief Merge a result.
91   ///
92   /// This means to merge the parameters and to set the Type to the most
93   /// specific Type that matches both.
94   ValidatorResult &merge(const ValidatorResult &ToMerge) {
95     Type = std::max(Type, ToMerge.Type);
96     addParamsFrom(ToMerge);
97     return *this;
98   }
99 
100   void print(raw_ostream &OS) {
101     switch (Type) {
102     case SCEVType::INT:
103       OS << "SCEVType::INT";
104       break;
105     case SCEVType::PARAM:
106       OS << "SCEVType::PARAM";
107       break;
108     case SCEVType::IV:
109       OS << "SCEVType::IV";
110       break;
111     case SCEVType::INVALID:
112       OS << "SCEVType::INVALID";
113       break;
114     }
115   }
116 };
117 
118 raw_ostream &operator<<(raw_ostream &OS, class ValidatorResult &VR) {
119   VR.print(OS);
120   return OS;
121 }
122 
123 /// Check if a SCEV is valid in a SCoP.
124 struct SCEVValidator
125     : public SCEVVisitor<SCEVValidator, class ValidatorResult> {
126 private:
127   const Region *R;
128   ScalarEvolution &SE;
129   const Value *BaseAddress;
130 
131 public:
132   SCEVValidator(const Region *R, ScalarEvolution &SE, const Value *BaseAddress)
133       : R(R), SE(SE), BaseAddress(BaseAddress) {}
134 
135   class ValidatorResult visitConstant(const SCEVConstant *Constant) {
136     return ValidatorResult(SCEVType::INT);
137   }
138 
139   class ValidatorResult visitTruncateExpr(const SCEVTruncateExpr *Expr) {
140     ValidatorResult Op = visit(Expr->getOperand());
141 
142     switch (Op.getType()) {
143     case SCEVType::INT:
144     case SCEVType::PARAM:
145       // We currently do not represent a truncate expression as an affine
146       // expression. If it is constant during Scop execution, we treat it as a
147       // parameter.
148       return ValidatorResult(SCEVType::PARAM, Expr);
149     case SCEVType::IV:
150       DEBUG(dbgs() << "INVALID: Truncation of SCEVType::IV expression");
151       return ValidatorResult(SCEVType::INVALID);
152     case SCEVType::INVALID:
153       return Op;
154     }
155 
156     llvm_unreachable("Unknown SCEVType");
157   }
158 
159   class ValidatorResult visitZeroExtendExpr(const SCEVZeroExtendExpr *Expr) {
160 
161     // Pattern matching rules to capture some bit and modulo computations:
162     //
163     //       EXP   %  2^C      <==>
164     // [A] (i + c) & (2^C - 1)  ==> zext iC {c,+,1}<%for_i> to IXX
165     // [B] (p + q) & (2^C - 1)  ==> zext iC (trunc iXX %p_add_q to iC) to iXX
166     // [C] (i + p) & (2^C - 1)  ==> zext iC {p & (2^C - 1),+,1}<%for_i> to iXX
167     //                          ==> zext iC {trunc iXX %p to iC,+,1}<%for_i> to
168 
169     // Check for [A] and [C].
170     const SCEV *OpS = Expr->getOperand();
171     if (auto *OpAR = dyn_cast<SCEVAddRecExpr>(OpS)) {
172       const SCEV *OpARStart = OpAR->getStart();
173 
174       // Special case for [C].
175       if (auto *OpARStartTR = dyn_cast<SCEVTruncateExpr>(OpARStart))
176         OpARStart = OpARStartTR->getOperand();
177 
178       ValidatorResult OpARStartVR = visit(OpARStart);
179       if (OpARStartVR.isConstant() && OpAR->getStepRecurrence(SE)->isOne())
180         return OpARStartVR;
181     }
182 
183     // Check for [B].
184     if (auto *OpTR = dyn_cast<SCEVTruncateExpr>(OpS)) {
185       ValidatorResult OpTRVR = visit(OpTR->getOperand());
186       if (OpTRVR.isConstant())
187         return OpTRVR;
188     }
189 
190     ValidatorResult Op = visit(OpS);
191     switch (Op.getType()) {
192     case SCEVType::INT:
193     case SCEVType::PARAM:
194       // We currently do not represent a truncate expression as an affine
195       // expression. If it is constant during Scop execution, we treat it as a
196       // parameter.
197       return ValidatorResult(SCEVType::PARAM, Expr);
198     case SCEVType::IV:
199       DEBUG(dbgs() << "INVALID: ZeroExtend of SCEVType::IV expression");
200       return ValidatorResult(SCEVType::INVALID);
201     case SCEVType::INVALID:
202       return Op;
203     }
204 
205     llvm_unreachable("Unknown SCEVType");
206   }
207 
208   class ValidatorResult visitSignExtendExpr(const SCEVSignExtendExpr *Expr) {
209     // We currently allow only signed SCEV expressions. In the case of a
210     // signed value, a sign extend is a noop.
211     //
212     // TODO: Reconsider this when we add support for unsigned values.
213     return visit(Expr->getOperand());
214   }
215 
216   class ValidatorResult visitAddExpr(const SCEVAddExpr *Expr) {
217     ValidatorResult Return(SCEVType::INT);
218 
219     for (int i = 0, e = Expr->getNumOperands(); i < e; ++i) {
220       ValidatorResult Op = visit(Expr->getOperand(i));
221       Return.merge(Op);
222 
223       // Early exit.
224       if (!Return.isValid())
225         break;
226     }
227 
228     // TODO: Check for NSW and NUW.
229     return Return;
230   }
231 
232   class ValidatorResult visitMulExpr(const SCEVMulExpr *Expr) {
233     ValidatorResult Return(SCEVType::INT);
234 
235     bool HasMultipleParams = false;
236 
237     for (int i = 0, e = Expr->getNumOperands(); i < e; ++i) {
238       ValidatorResult Op = visit(Expr->getOperand(i));
239 
240       if (Op.isINT())
241         continue;
242 
243       if (Op.isPARAM() && Return.isPARAM()) {
244         HasMultipleParams = true;
245         continue;
246       }
247 
248       if ((Op.isIV() || Op.isPARAM()) && !Return.isINT()) {
249         DEBUG(dbgs() << "INVALID: More than one non-int operand in MulExpr\n"
250                      << "\tExpr: " << *Expr << "\n"
251                      << "\tPrevious expression type: " << Return << "\n"
252                      << "\tNext operand (" << Op
253                      << "): " << *Expr->getOperand(i) << "\n");
254 
255         return ValidatorResult(SCEVType::INVALID);
256       }
257 
258       Return.merge(Op);
259     }
260 
261     if (HasMultipleParams)
262       return ValidatorResult(SCEVType::PARAM, Expr);
263 
264     // TODO: Check for NSW and NUW.
265     return Return;
266   }
267 
268   class ValidatorResult visitUDivExpr(const SCEVUDivExpr *Expr) {
269     ValidatorResult LHS = visit(Expr->getLHS());
270     ValidatorResult RHS = visit(Expr->getRHS());
271 
272     // We currently do not represent an unsigned division as an affine
273     // expression. If the division is constant during Scop execution we treat it
274     // as a parameter, otherwise we bail out.
275     if (LHS.isConstant() && RHS.isConstant())
276       return ValidatorResult(SCEVType::PARAM, Expr);
277 
278     DEBUG(dbgs() << "INVALID: unsigned division of non-constant expressions");
279     return ValidatorResult(SCEVType::INVALID);
280   }
281 
282   class ValidatorResult visitAddRecExpr(const SCEVAddRecExpr *Expr) {
283     if (!Expr->isAffine()) {
284       DEBUG(dbgs() << "INVALID: AddRec is not affine");
285       return ValidatorResult(SCEVType::INVALID);
286     }
287 
288     ValidatorResult Start = visit(Expr->getStart());
289     ValidatorResult Recurrence = visit(Expr->getStepRecurrence(SE));
290 
291     if (!Start.isValid())
292       return Start;
293 
294     if (!Recurrence.isValid())
295       return Recurrence;
296 
297     if (R->contains(Expr->getLoop())) {
298       if (Recurrence.isINT()) {
299         ValidatorResult Result(SCEVType::IV);
300         Result.addParamsFrom(Start);
301         return Result;
302       }
303 
304       DEBUG(dbgs() << "INVALID: AddRec within scop has non-int"
305                       "recurrence part");
306       return ValidatorResult(SCEVType::INVALID);
307     }
308 
309     assert(Start.isConstant() && Recurrence.isConstant() &&
310            "Expected 'Start' and 'Recurrence' to be constant");
311 
312     // Directly generate ValidatorResult for Expr if 'start' is zero.
313     if (Expr->getStart()->isZero())
314       return ValidatorResult(SCEVType::PARAM, Expr);
315 
316     // Translate AddRecExpr from '{start, +, inc}' into 'start + {0, +, inc}'
317     // if 'start' is not zero.
318     const SCEV *ZeroStartExpr = SE.getAddRecExpr(
319         SE.getConstant(Expr->getStart()->getType(), 0),
320         Expr->getStepRecurrence(SE), Expr->getLoop(), SCEV::FlagAnyWrap);
321 
322     ValidatorResult ZeroStartResult =
323         ValidatorResult(SCEVType::PARAM, ZeroStartExpr);
324     ZeroStartResult.addParamsFrom(Start);
325 
326     return ZeroStartResult;
327   }
328 
329   class ValidatorResult visitSMaxExpr(const SCEVSMaxExpr *Expr) {
330     ValidatorResult Return(SCEVType::INT);
331 
332     for (int i = 0, e = Expr->getNumOperands(); i < e; ++i) {
333       ValidatorResult Op = visit(Expr->getOperand(i));
334 
335       if (!Op.isValid())
336         return Op;
337 
338       Return.merge(Op);
339     }
340 
341     return Return;
342   }
343 
344   class ValidatorResult visitUMaxExpr(const SCEVUMaxExpr *Expr) {
345     // We do not support unsigned operations. If 'Expr' is constant during Scop
346     // execution we treat this as a parameter, otherwise we bail out.
347     for (int i = 0, e = Expr->getNumOperands(); i < e; ++i) {
348       ValidatorResult Op = visit(Expr->getOperand(i));
349 
350       if (!Op.isConstant()) {
351         DEBUG(dbgs() << "INVALID: UMaxExpr has a non-constant operand");
352         return ValidatorResult(SCEVType::INVALID);
353       }
354     }
355 
356     return ValidatorResult(SCEVType::PARAM, Expr);
357   }
358 
359   ValidatorResult visitUnknown(const SCEVUnknown *Expr) {
360     Value *V = Expr->getValue();
361 
362     // We currently only support integer types. It may be useful to support
363     // pointer types, e.g. to support code like:
364     //
365     //   if (A)
366     //     A[i] = 1;
367     //
368     // See test/CodeGen/20120316-InvalidCast.ll
369     if (!Expr->getType()->isIntegerTy()) {
370       DEBUG(dbgs() << "INVALID: UnknownExpr is not an integer type");
371       return ValidatorResult(SCEVType::INVALID);
372     }
373 
374     if (isa<UndefValue>(V)) {
375       DEBUG(dbgs() << "INVALID: UnknownExpr references an undef value");
376       return ValidatorResult(SCEVType::INVALID);
377     }
378 
379     if (auto *I = dyn_cast<Instruction>(Expr->getValue())) {
380       if (I->getOpcode() == Instruction::SRem) {
381 
382         ValidatorResult Op0 = visit(SE.getSCEV(I->getOperand(0)));
383         if (!Op0.isValid())
384           return ValidatorResult(SCEVType::INVALID);
385 
386         ValidatorResult Op1 = visit(SE.getSCEV(I->getOperand(1)));
387         if (!Op1.isValid() || !Op1.isINT())
388           return ValidatorResult(SCEVType::INVALID);
389 
390         Op0.merge(Op1);
391         return Op0;
392       }
393 
394       if (R->contains(I)) {
395         DEBUG(dbgs() << "INVALID: UnknownExpr references an instruction "
396                         "within the region\n");
397         return ValidatorResult(SCEVType::INVALID);
398       }
399     }
400 
401     if (BaseAddress == V) {
402       DEBUG(dbgs() << "INVALID: UnknownExpr references BaseAddress\n");
403       return ValidatorResult(SCEVType::INVALID);
404     }
405 
406     return ValidatorResult(SCEVType::PARAM, Expr);
407   }
408 };
409 
410 /// @brief Check whether a SCEV refers to an SSA name defined inside a region.
411 ///
412 struct SCEVInRegionDependences
413     : public SCEVVisitor<SCEVInRegionDependences, bool> {
414 public:
415   /// Returns true when the SCEV has SSA names defined in region R.
416   static bool hasDependences(const SCEV *S, const Region *R) {
417     SCEVInRegionDependences Ignore(R);
418     return Ignore.visit(S);
419   }
420 
421   SCEVInRegionDependences(const Region *R) : R(R) {}
422 
423   bool visit(const SCEV *Expr) {
424     return SCEVVisitor<SCEVInRegionDependences, bool>::visit(Expr);
425   }
426 
427   bool visitConstant(const SCEVConstant *Constant) { return false; }
428 
429   bool visitTruncateExpr(const SCEVTruncateExpr *Expr) {
430     return visit(Expr->getOperand());
431   }
432 
433   bool visitZeroExtendExpr(const SCEVZeroExtendExpr *Expr) {
434     return visit(Expr->getOperand());
435   }
436 
437   bool visitSignExtendExpr(const SCEVSignExtendExpr *Expr) {
438     return visit(Expr->getOperand());
439   }
440 
441   bool visitAddExpr(const SCEVAddExpr *Expr) {
442     for (int i = 0, e = Expr->getNumOperands(); i < e; ++i)
443       if (visit(Expr->getOperand(i)))
444         return true;
445 
446     return false;
447   }
448 
449   bool visitMulExpr(const SCEVMulExpr *Expr) {
450     for (int i = 0, e = Expr->getNumOperands(); i < e; ++i)
451       if (visit(Expr->getOperand(i)))
452         return true;
453 
454     return false;
455   }
456 
457   bool visitUDivExpr(const SCEVUDivExpr *Expr) {
458     if (visit(Expr->getLHS()))
459       return true;
460 
461     if (visit(Expr->getRHS()))
462       return true;
463 
464     return false;
465   }
466 
467   bool visitAddRecExpr(const SCEVAddRecExpr *Expr) {
468     if (visit(Expr->getStart()))
469       return true;
470 
471     for (size_t i = 0; i < Expr->getNumOperands(); ++i)
472       if (visit(Expr->getOperand(i)))
473         return true;
474 
475     return false;
476   }
477 
478   bool visitSMaxExpr(const SCEVSMaxExpr *Expr) {
479     for (size_t i = 0; i < Expr->getNumOperands(); ++i)
480       if (visit(Expr->getOperand(i)))
481         return true;
482 
483     return false;
484   }
485 
486   bool visitUMaxExpr(const SCEVUMaxExpr *Expr) {
487     for (size_t i = 0; i < Expr->getNumOperands(); ++i)
488       if (visit(Expr->getOperand(i)))
489         return true;
490 
491     return false;
492   }
493 
494   bool visitUnknown(const SCEVUnknown *Expr) {
495     Instruction *Inst = dyn_cast<Instruction>(Expr->getValue());
496 
497     // Return true when Inst is defined inside the region R.
498     if (Inst && R->contains(Inst))
499       return true;
500 
501     return false;
502   }
503 
504 private:
505   const Region *R;
506 };
507 
508 namespace polly {
509 bool hasScalarDepsInsideRegion(const SCEV *Expr, const Region *R) {
510   return SCEVInRegionDependences::hasDependences(Expr, R);
511 }
512 
513 bool isAffineExpr(const Region *R, const SCEV *Expr, ScalarEvolution &SE,
514                   const Value *BaseAddress) {
515   if (isa<SCEVCouldNotCompute>(Expr))
516     return false;
517 
518   SCEVValidator Validator(R, SE, BaseAddress);
519   DEBUG(dbgs() << "\n"; dbgs() << "Expr: " << *Expr << "\n";
520         dbgs() << "Region: " << R->getNameStr() << "\n"; dbgs() << " -> ");
521 
522   ValidatorResult Result = Validator.visit(Expr);
523 
524   DEBUG(if (Result.isValid()) dbgs() << "VALID\n"; dbgs() << "\n";);
525 
526   return Result.isValid();
527 }
528 
529 std::vector<const SCEV *> getParamsInAffineExpr(const Region *R,
530                                                 const SCEV *Expr,
531                                                 ScalarEvolution &SE,
532                                                 const Value *BaseAddress) {
533   if (isa<SCEVCouldNotCompute>(Expr))
534     return std::vector<const SCEV *>();
535 
536   SCEVValidator Validator(R, SE, BaseAddress);
537   ValidatorResult Result = Validator.visit(Expr);
538 
539   return Result.getParameters();
540 }
541 }
542