xref: /llvm-project/polly/lib/Support/SCEVValidator.cpp (revision bd93df937a6441db4aff67191ca0bb486554c34b)
1 
2 #include "polly/Support/SCEVValidator.h"
3 #include "polly/ScopDetection.h"
4 #include "llvm/Analysis/RegionInfo.h"
5 #include "llvm/Analysis/ScalarEvolution.h"
6 #include "llvm/Analysis/ScalarEvolutionExpressions.h"
7 #include "llvm/Support/Debug.h"
8 
9 using namespace llvm;
10 using namespace polly;
11 
12 #define DEBUG_TYPE "polly-scev-validator"
13 
namespace SCEVType {
/// Classification the validator assigns to every SCEV expression.
///
/// The numeric order of the enumerators is significant: they are sorted from
/// the most restrictive classification to the least restrictive one, so that
/// combining two classifications can simply keep the larger value. A valid
/// SCEV of classification X only has sub-expressions classified as X or lower.
enum TYPE {
  /// An integer value.
  INT = 0,

  /// Constant while the SCoP executes, but possibly dependent on parameters
  /// that are unknown at compile time.
  PARAM = 1,

  /// May change while the SCoP executes.
  IV = 2,

  /// Not representable.
  INVALID = 3
};
} // namespace SCEVType
36 
37 /// The result the validator returns for a SCEV expression.
38 class ValidatorResult final {
39   /// The type of the expression
40   SCEVType::TYPE Type;
41 
42   /// The set of Parameters in the expression.
43   ParameterSetTy Parameters;
44 
45 public:
46   /// The copy constructor
47   ValidatorResult(const ValidatorResult &Source) {
48     Type = Source.Type;
49     Parameters = Source.Parameters;
50   }
51 
52   /// Construct a result with a certain type and no parameters.
53   ValidatorResult(SCEVType::TYPE Type) : Type(Type) {
54     assert(Type != SCEVType::PARAM && "Did you forget to pass the parameter");
55   }
56 
57   /// Construct a result with a certain type and a single parameter.
58   ValidatorResult(SCEVType::TYPE Type, const SCEV *Expr) : Type(Type) {
59     Parameters.insert(Expr);
60   }
61 
62   /// Get the type of the ValidatorResult.
63   SCEVType::TYPE getType() { return Type; }
64 
65   /// Is the analyzed SCEV constant during the execution of the SCoP.
66   bool isConstant() { return Type == SCEVType::INT || Type == SCEVType::PARAM; }
67 
68   /// Is the analyzed SCEV valid.
69   bool isValid() { return Type != SCEVType::INVALID; }
70 
71   /// Is the analyzed SCEV of Type IV.
72   bool isIV() { return Type == SCEVType::IV; }
73 
74   /// Is the analyzed SCEV of Type INT.
75   bool isINT() { return Type == SCEVType::INT; }
76 
77   /// Is the analyzed SCEV of Type PARAM.
78   bool isPARAM() { return Type == SCEVType::PARAM; }
79 
80   /// Get the parameters of this validator result.
81   const ParameterSetTy &getParameters() { return Parameters; }
82 
83   /// Add the parameters of Source to this result.
84   void addParamsFrom(const ValidatorResult &Source) {
85     Parameters.insert(Source.Parameters.begin(), Source.Parameters.end());
86   }
87 
88   /// Merge a result.
89   ///
90   /// This means to merge the parameters and to set the Type to the most
91   /// specific Type that matches both.
92   void merge(const ValidatorResult &ToMerge) {
93     Type = std::max(Type, ToMerge.Type);
94     addParamsFrom(ToMerge);
95   }
96 
97   void print(raw_ostream &OS) {
98     switch (Type) {
99     case SCEVType::INT:
100       OS << "SCEVType::INT";
101       break;
102     case SCEVType::PARAM:
103       OS << "SCEVType::PARAM";
104       break;
105     case SCEVType::IV:
106       OS << "SCEVType::IV";
107       break;
108     case SCEVType::INVALID:
109       OS << "SCEVType::INVALID";
110       break;
111     }
112   }
113 };
114 
115 raw_ostream &operator<<(raw_ostream &OS, ValidatorResult &VR) {
116   VR.print(OS);
117   return OS;
118 }
119 
120 /// Check if a SCEV is valid in a SCoP.
121 class SCEVValidator : public SCEVVisitor<SCEVValidator, ValidatorResult> {
122 private:
123   const Region *R;
124   Loop *Scope;
125   ScalarEvolution &SE;
126   InvariantLoadsSetTy *ILS;
127 
128 public:
129   SCEVValidator(const Region *R, Loop *Scope, ScalarEvolution &SE,
130                 InvariantLoadsSetTy *ILS)
131       : R(R), Scope(Scope), SE(SE), ILS(ILS) {}
132 
133   ValidatorResult visitConstant(const SCEVConstant *Constant) {
134     return ValidatorResult(SCEVType::INT);
135   }
136 
137   ValidatorResult visitZeroExtendOrTruncateExpr(const SCEV *Expr,
138                                                 const SCEV *Operand) {
139     ValidatorResult Op = visit(Operand);
140     auto Type = Op.getType();
141 
142     // If unsigned operations are allowed return the operand, otherwise
143     // check if we can model the expression without unsigned assumptions.
144     if (PollyAllowUnsignedOperations || Type == SCEVType::INVALID)
145       return Op;
146 
147     if (Type == SCEVType::IV)
148       return ValidatorResult(SCEVType::INVALID);
149     return ValidatorResult(SCEVType::PARAM, Expr);
150   }
151 
152   ValidatorResult visitPtrToIntExpr(const SCEVPtrToIntExpr *Expr) {
153     return visit(Expr->getOperand());
154   }
155 
156   ValidatorResult visitTruncateExpr(const SCEVTruncateExpr *Expr) {
157     return visitZeroExtendOrTruncateExpr(Expr, Expr->getOperand());
158   }
159 
160   ValidatorResult visitZeroExtendExpr(const SCEVZeroExtendExpr *Expr) {
161     return visitZeroExtendOrTruncateExpr(Expr, Expr->getOperand());
162   }
163 
164   ValidatorResult visitSignExtendExpr(const SCEVSignExtendExpr *Expr) {
165     return visit(Expr->getOperand());
166   }
167 
168   ValidatorResult visitAddExpr(const SCEVAddExpr *Expr) {
169     ValidatorResult Return(SCEVType::INT);
170 
171     for (int i = 0, e = Expr->getNumOperands(); i < e; ++i) {
172       ValidatorResult Op = visit(Expr->getOperand(i));
173       Return.merge(Op);
174 
175       // Early exit.
176       if (!Return.isValid())
177         break;
178     }
179 
180     return Return;
181   }
182 
183   ValidatorResult visitMulExpr(const SCEVMulExpr *Expr) {
184     ValidatorResult Return(SCEVType::INT);
185 
186     bool HasMultipleParams = false;
187 
188     for (int i = 0, e = Expr->getNumOperands(); i < e; ++i) {
189       ValidatorResult Op = visit(Expr->getOperand(i));
190 
191       if (Op.isINT())
192         continue;
193 
194       if (Op.isPARAM() && Return.isPARAM()) {
195         HasMultipleParams = true;
196         continue;
197       }
198 
199       if ((Op.isIV() || Op.isPARAM()) && !Return.isINT()) {
200         LLVM_DEBUG(
201             dbgs() << "INVALID: More than one non-int operand in MulExpr\n"
202                    << "\tExpr: " << *Expr << "\n"
203                    << "\tPrevious expression type: " << Return << "\n"
204                    << "\tNext operand (" << Op << "): " << *Expr->getOperand(i)
205                    << "\n");
206 
207         return ValidatorResult(SCEVType::INVALID);
208       }
209 
210       Return.merge(Op);
211     }
212 
213     if (HasMultipleParams && Return.isValid())
214       return ValidatorResult(SCEVType::PARAM, Expr);
215 
216     return Return;
217   }
218 
219   ValidatorResult visitAddRecExpr(const SCEVAddRecExpr *Expr) {
220     if (!Expr->isAffine()) {
221       LLVM_DEBUG(dbgs() << "INVALID: AddRec is not affine");
222       return ValidatorResult(SCEVType::INVALID);
223     }
224 
225     ValidatorResult Start = visit(Expr->getStart());
226     ValidatorResult Recurrence = visit(Expr->getStepRecurrence(SE));
227 
228     if (!Start.isValid())
229       return Start;
230 
231     if (!Recurrence.isValid())
232       return Recurrence;
233 
234     auto *L = Expr->getLoop();
235     if (R->contains(L) && (!Scope || !L->contains(Scope))) {
236       LLVM_DEBUG(
237           dbgs() << "INVALID: Loop of AddRec expression boxed in an a "
238                     "non-affine subregion or has a non-synthesizable exit "
239                     "value.");
240       return ValidatorResult(SCEVType::INVALID);
241     }
242 
243     if (R->contains(L)) {
244       if (Recurrence.isINT()) {
245         ValidatorResult Result(SCEVType::IV);
246         Result.addParamsFrom(Start);
247         return Result;
248       }
249 
250       LLVM_DEBUG(dbgs() << "INVALID: AddRec within scop has non-int"
251                            "recurrence part");
252       return ValidatorResult(SCEVType::INVALID);
253     }
254 
255     assert(Recurrence.isConstant() && "Expected 'Recurrence' to be constant");
256 
257     // Directly generate ValidatorResult for Expr if 'start' is zero.
258     if (Expr->getStart()->isZero())
259       return ValidatorResult(SCEVType::PARAM, Expr);
260 
261     // Translate AddRecExpr from '{start, +, inc}' into 'start + {0, +, inc}'
262     // if 'start' is not zero.
263     const SCEV *ZeroStartExpr = SE.getAddRecExpr(
264         SE.getConstant(Expr->getStart()->getType(), 0),
265         Expr->getStepRecurrence(SE), Expr->getLoop(), Expr->getNoWrapFlags());
266 
267     ValidatorResult ZeroStartResult =
268         ValidatorResult(SCEVType::PARAM, ZeroStartExpr);
269     ZeroStartResult.addParamsFrom(Start);
270 
271     return ZeroStartResult;
272   }
273 
274   ValidatorResult visitSMaxExpr(const SCEVSMaxExpr *Expr) {
275     ValidatorResult Return(SCEVType::INT);
276 
277     for (int i = 0, e = Expr->getNumOperands(); i < e; ++i) {
278       ValidatorResult Op = visit(Expr->getOperand(i));
279 
280       if (!Op.isValid())
281         return Op;
282 
283       Return.merge(Op);
284     }
285 
286     return Return;
287   }
288 
289   ValidatorResult visitSMinExpr(const SCEVSMinExpr *Expr) {
290     ValidatorResult Return(SCEVType::INT);
291 
292     for (int i = 0, e = Expr->getNumOperands(); i < e; ++i) {
293       ValidatorResult Op = visit(Expr->getOperand(i));
294 
295       if (!Op.isValid())
296         return Op;
297 
298       Return.merge(Op);
299     }
300 
301     return Return;
302   }
303 
304   ValidatorResult visitUMaxExpr(const SCEVUMaxExpr *Expr) {
305     // We do not support unsigned max operations. If 'Expr' is constant during
306     // Scop execution we treat this as a parameter, otherwise we bail out.
307     for (int i = 0, e = Expr->getNumOperands(); i < e; ++i) {
308       ValidatorResult Op = visit(Expr->getOperand(i));
309 
310       if (!Op.isConstant()) {
311         LLVM_DEBUG(dbgs() << "INVALID: UMaxExpr has a non-constant operand");
312         return ValidatorResult(SCEVType::INVALID);
313       }
314     }
315 
316     return ValidatorResult(SCEVType::PARAM, Expr);
317   }
318 
319   ValidatorResult visitUMinExpr(const SCEVUMinExpr *Expr) {
320     // We do not support unsigned min operations. If 'Expr' is constant during
321     // Scop execution we treat this as a parameter, otherwise we bail out.
322     for (int i = 0, e = Expr->getNumOperands(); i < e; ++i) {
323       ValidatorResult Op = visit(Expr->getOperand(i));
324 
325       if (!Op.isConstant()) {
326         LLVM_DEBUG(dbgs() << "INVALID: UMinExpr has a non-constant operand");
327         return ValidatorResult(SCEVType::INVALID);
328       }
329     }
330 
331     return ValidatorResult(SCEVType::PARAM, Expr);
332   }
333 
334   ValidatorResult visitSequentialUMinExpr(const SCEVSequentialUMinExpr *Expr) {
335     // We do not support unsigned min operations. If 'Expr' is constant during
336     // Scop execution we treat this as a parameter, otherwise we bail out.
337     for (int i = 0, e = Expr->getNumOperands(); i < e; ++i) {
338       ValidatorResult Op = visit(Expr->getOperand(i));
339 
340       if (!Op.isConstant()) {
341         LLVM_DEBUG(
342             dbgs()
343             << "INVALID: SCEVSequentialUMinExpr has a non-constant operand");
344         return ValidatorResult(SCEVType::INVALID);
345       }
346     }
347 
348     return ValidatorResult(SCEVType::PARAM, Expr);
349   }
350 
351   ValidatorResult visitGenericInst(Instruction *I, const SCEV *S) {
352     if (R->contains(I)) {
353       LLVM_DEBUG(dbgs() << "INVALID: UnknownExpr references an instruction "
354                            "within the region\n");
355       return ValidatorResult(SCEVType::INVALID);
356     }
357 
358     return ValidatorResult(SCEVType::PARAM, S);
359   }
360 
361   ValidatorResult visitLoadInstruction(Instruction *I, const SCEV *S) {
362     if (R->contains(I) && ILS) {
363       ILS->insert(cast<LoadInst>(I));
364       return ValidatorResult(SCEVType::PARAM, S);
365     }
366 
367     return visitGenericInst(I, S);
368   }
369 
370   ValidatorResult visitDivision(const SCEV *Dividend, const SCEV *Divisor,
371                                 const SCEV *DivExpr,
372                                 Instruction *SDiv = nullptr) {
373 
374     // First check if we might be able to model the division, thus if the
375     // divisor is constant. If so, check the dividend, otherwise check if
376     // the whole division can be seen as a parameter.
377     if (isa<SCEVConstant>(Divisor) && !Divisor->isZero())
378       return visit(Dividend);
379 
380     // For signed divisions use the SDiv instruction to check for a parameter
381     // division, for unsigned divisions check the operands.
382     if (SDiv)
383       return visitGenericInst(SDiv, DivExpr);
384 
385     ValidatorResult LHS = visit(Dividend);
386     ValidatorResult RHS = visit(Divisor);
387     if (LHS.isConstant() && RHS.isConstant())
388       return ValidatorResult(SCEVType::PARAM, DivExpr);
389 
390     LLVM_DEBUG(
391         dbgs() << "INVALID: unsigned division of non-constant expressions");
392     return ValidatorResult(SCEVType::INVALID);
393   }
394 
395   ValidatorResult visitUDivExpr(const SCEVUDivExpr *Expr) {
396     if (!PollyAllowUnsignedOperations)
397       return ValidatorResult(SCEVType::INVALID);
398 
399     auto *Dividend = Expr->getLHS();
400     auto *Divisor = Expr->getRHS();
401     return visitDivision(Dividend, Divisor, Expr);
402   }
403 
404   ValidatorResult visitSDivInstruction(Instruction *SDiv, const SCEV *Expr) {
405     assert(SDiv->getOpcode() == Instruction::SDiv &&
406            "Assumed SDiv instruction!");
407 
408     auto *Dividend = SE.getSCEV(SDiv->getOperand(0));
409     auto *Divisor = SE.getSCEV(SDiv->getOperand(1));
410     return visitDivision(Dividend, Divisor, Expr, SDiv);
411   }
412 
413   ValidatorResult visitSRemInstruction(Instruction *SRem, const SCEV *S) {
414     assert(SRem->getOpcode() == Instruction::SRem &&
415            "Assumed SRem instruction!");
416 
417     auto *Divisor = SRem->getOperand(1);
418     auto *CI = dyn_cast<ConstantInt>(Divisor);
419     if (!CI || CI->isZeroValue())
420       return visitGenericInst(SRem, S);
421 
422     auto *Dividend = SRem->getOperand(0);
423     auto *DividendSCEV = SE.getSCEV(Dividend);
424     return visit(DividendSCEV);
425   }
426 
427   ValidatorResult visitUnknown(const SCEVUnknown *Expr) {
428     Value *V = Expr->getValue();
429 
430     if (!Expr->getType()->isIntegerTy() && !Expr->getType()->isPointerTy()) {
431       LLVM_DEBUG(dbgs() << "INVALID: UnknownExpr is not an integer or pointer");
432       return ValidatorResult(SCEVType::INVALID);
433     }
434 
435     if (isa<UndefValue>(V)) {
436       LLVM_DEBUG(dbgs() << "INVALID: UnknownExpr references an undef value");
437       return ValidatorResult(SCEVType::INVALID);
438     }
439 
440     if (Instruction *I = dyn_cast<Instruction>(Expr->getValue())) {
441       switch (I->getOpcode()) {
442       case Instruction::IntToPtr:
443         return visit(SE.getSCEVAtScope(I->getOperand(0), Scope));
444       case Instruction::Load:
445         return visitLoadInstruction(I, Expr);
446       case Instruction::SDiv:
447         return visitSDivInstruction(I, Expr);
448       case Instruction::SRem:
449         return visitSRemInstruction(I, Expr);
450       default:
451         return visitGenericInst(I, Expr);
452       }
453     }
454 
455     if (Expr->getType()->isPointerTy()) {
456       if (isa<ConstantPointerNull>(V))
457         return ValidatorResult(SCEVType::INT); // "int"
458     }
459 
460     return ValidatorResult(SCEVType::PARAM, Expr);
461   }
462 };
463 
464 /// Check whether a SCEV refers to an SSA name defined inside a region.
465 class SCEVInRegionDependences final {
466   const Region *R;
467   Loop *Scope;
468   const InvariantLoadsSetTy &ILS;
469   bool AllowLoops;
470   bool HasInRegionDeps = false;
471 
472 public:
473   SCEVInRegionDependences(const Region *R, Loop *Scope, bool AllowLoops,
474                           const InvariantLoadsSetTy &ILS)
475       : R(R), Scope(Scope), ILS(ILS), AllowLoops(AllowLoops) {}
476 
477   bool follow(const SCEV *S) {
478     if (auto Unknown = dyn_cast<SCEVUnknown>(S)) {
479       Instruction *Inst = dyn_cast<Instruction>(Unknown->getValue());
480 
481       if (Inst) {
482         // When we invariant load hoist a load, we first make sure that there
483         // can be no dependences created by it in the Scop region. So, we should
484         // not consider scalar dependences to `LoadInst`s that are invariant
485         // load hoisted.
486         //
487         // If this check is not present, then we create data dependences which
488         // are strictly not necessary by tracking the invariant load as a
489         // scalar.
490         LoadInst *LI = dyn_cast<LoadInst>(Inst);
491         if (LI && ILS.contains(LI))
492           return false;
493       }
494 
495       // Return true when Inst is defined inside the region R.
496       if (!Inst || !R->contains(Inst))
497         return true;
498 
499       HasInRegionDeps = true;
500       return false;
501     }
502 
503     if (auto AddRec = dyn_cast<SCEVAddRecExpr>(S)) {
504       if (AllowLoops)
505         return true;
506 
507       auto *L = AddRec->getLoop();
508       if (R->contains(L) && !L->contains(Scope)) {
509         HasInRegionDeps = true;
510         return false;
511       }
512     }
513 
514     return true;
515   }
516   bool isDone() { return false; }
517   bool hasDependences() { return HasInRegionDeps; }
518 };
519 
520 /// Find all loops referenced in SCEVAddRecExprs.
521 class SCEVFindLoops final {
522   SetVector<const Loop *> &Loops;
523 
524 public:
525   SCEVFindLoops(SetVector<const Loop *> &Loops) : Loops(Loops) {}
526 
527   bool follow(const SCEV *S) {
528     if (const SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(S))
529       Loops.insert(AddRec->getLoop());
530     return true;
531   }
532   bool isDone() { return false; }
533 };
534 
535 void polly::findLoops(const SCEV *Expr, SetVector<const Loop *> &Loops) {
536   SCEVFindLoops FindLoops(Loops);
537   SCEVTraversal<SCEVFindLoops> ST(FindLoops);
538   ST.visitAll(Expr);
539 }
540 
541 /// Find all values referenced in SCEVUnknowns.
542 class SCEVFindValues final {
543   ScalarEvolution &SE;
544   SetVector<Value *> &Values;
545 
546 public:
547   SCEVFindValues(ScalarEvolution &SE, SetVector<Value *> &Values)
548       : SE(SE), Values(Values) {}
549 
550   bool follow(const SCEV *S) {
551     const SCEVUnknown *Unknown = dyn_cast<SCEVUnknown>(S);
552     if (!Unknown)
553       return true;
554 
555     Values.insert(Unknown->getValue());
556     Instruction *Inst = dyn_cast<Instruction>(Unknown->getValue());
557     if (!Inst || (Inst->getOpcode() != Instruction::SRem &&
558                   Inst->getOpcode() != Instruction::SDiv))
559       return false;
560 
561     auto *Dividend = SE.getSCEV(Inst->getOperand(1));
562     if (!isa<SCEVConstant>(Dividend))
563       return false;
564 
565     auto *Divisor = SE.getSCEV(Inst->getOperand(0));
566     SCEVFindValues FindValues(SE, Values);
567     SCEVTraversal<SCEVFindValues> ST(FindValues);
568     ST.visitAll(Dividend);
569     ST.visitAll(Divisor);
570 
571     return false;
572   }
573   bool isDone() { return false; }
574 };
575 
576 void polly::findValues(const SCEV *Expr, ScalarEvolution &SE,
577                        SetVector<Value *> &Values) {
578   SCEVFindValues FindValues(SE, Values);
579   SCEVTraversal<SCEVFindValues> ST(FindValues);
580   ST.visitAll(Expr);
581 }
582 
583 bool polly::hasScalarDepsInsideRegion(const SCEV *Expr, const Region *R,
584                                       llvm::Loop *Scope, bool AllowLoops,
585                                       const InvariantLoadsSetTy &ILS) {
586   SCEVInRegionDependences InRegionDeps(R, Scope, AllowLoops, ILS);
587   SCEVTraversal<SCEVInRegionDependences> ST(InRegionDeps);
588   ST.visitAll(Expr);
589   return InRegionDeps.hasDependences();
590 }
591 
592 bool polly::isAffineExpr(const Region *R, llvm::Loop *Scope, const SCEV *Expr,
593                          ScalarEvolution &SE, InvariantLoadsSetTy *ILS) {
594   if (isa<SCEVCouldNotCompute>(Expr))
595     return false;
596 
597   SCEVValidator Validator(R, Scope, SE, ILS);
598   LLVM_DEBUG({
599     dbgs() << "\n";
600     dbgs() << "Expr: " << *Expr << "\n";
601     dbgs() << "Region: " << R->getNameStr() << "\n";
602     dbgs() << " -> ";
603   });
604 
605   ValidatorResult Result = Validator.visit(Expr);
606 
607   LLVM_DEBUG({
608     if (Result.isValid())
609       dbgs() << "VALID\n";
610     dbgs() << "\n";
611   });
612 
613   return Result.isValid();
614 }
615 
616 static bool isAffineExpr(Value *V, const Region *R, Loop *Scope,
617                          ScalarEvolution &SE, ParameterSetTy &Params) {
618   auto *E = SE.getSCEV(V);
619   if (isa<SCEVCouldNotCompute>(E))
620     return false;
621 
622   SCEVValidator Validator(R, Scope, SE, nullptr);
623   ValidatorResult Result = Validator.visit(E);
624   if (!Result.isValid())
625     return false;
626 
627   auto ResultParams = Result.getParameters();
628   Params.insert(ResultParams.begin(), ResultParams.end());
629 
630   return true;
631 }
632 
633 bool polly::isAffineConstraint(Value *V, const Region *R, Loop *Scope,
634                                ScalarEvolution &SE, ParameterSetTy &Params,
635                                bool OrExpr) {
636   if (auto *ICmp = dyn_cast<ICmpInst>(V)) {
637     return isAffineConstraint(ICmp->getOperand(0), R, Scope, SE, Params,
638                               true) &&
639            isAffineConstraint(ICmp->getOperand(1), R, Scope, SE, Params, true);
640   } else if (auto *BinOp = dyn_cast<BinaryOperator>(V)) {
641     auto Opcode = BinOp->getOpcode();
642     if (Opcode == Instruction::And || Opcode == Instruction::Or)
643       return isAffineConstraint(BinOp->getOperand(0), R, Scope, SE, Params,
644                                 false) &&
645              isAffineConstraint(BinOp->getOperand(1), R, Scope, SE, Params,
646                                 false);
647     /* Fall through */
648   }
649 
650   if (!OrExpr)
651     return false;
652 
653   return ::isAffineExpr(V, R, Scope, SE, Params);
654 }
655 
656 ParameterSetTy polly::getParamsInAffineExpr(const Region *R, Loop *Scope,
657                                             const SCEV *Expr,
658                                             ScalarEvolution &SE) {
659   if (isa<SCEVCouldNotCompute>(Expr))
660     return ParameterSetTy();
661 
662   InvariantLoadsSetTy ILS;
663   SCEVValidator Validator(R, Scope, SE, &ILS);
664   ValidatorResult Result = Validator.visit(Expr);
665   assert(Result.isValid() && "Requested parameters for an invalid SCEV!");
666 
667   return Result.getParameters();
668 }
669 
670 std::pair<const SCEVConstant *, const SCEV *>
671 polly::extractConstantFactor(const SCEV *S, ScalarEvolution &SE) {
672   auto *ConstPart = cast<SCEVConstant>(SE.getConstant(S->getType(), 1));
673 
674   if (auto *Constant = dyn_cast<SCEVConstant>(S))
675     return std::make_pair(Constant, SE.getConstant(S->getType(), 1));
676 
677   auto *AddRec = dyn_cast<SCEVAddRecExpr>(S);
678   if (AddRec) {
679     auto *StartExpr = AddRec->getStart();
680     if (StartExpr->isZero()) {
681       auto StepPair = extractConstantFactor(AddRec->getStepRecurrence(SE), SE);
682       auto *LeftOverAddRec =
683           SE.getAddRecExpr(StartExpr, StepPair.second, AddRec->getLoop(),
684                            AddRec->getNoWrapFlags());
685       return std::make_pair(StepPair.first, LeftOverAddRec);
686     }
687     return std::make_pair(ConstPart, S);
688   }
689 
690   if (auto *Add = dyn_cast<SCEVAddExpr>(S)) {
691     SmallVector<const SCEV *, 4> LeftOvers;
692     auto Op0Pair = extractConstantFactor(Add->getOperand(0), SE);
693     auto *Factor = Op0Pair.first;
694     if (SE.isKnownNegative(Factor)) {
695       Factor = cast<SCEVConstant>(SE.getNegativeSCEV(Factor));
696       LeftOvers.push_back(SE.getNegativeSCEV(Op0Pair.second));
697     } else {
698       LeftOvers.push_back(Op0Pair.second);
699     }
700 
701     for (unsigned u = 1, e = Add->getNumOperands(); u < e; u++) {
702       auto OpUPair = extractConstantFactor(Add->getOperand(u), SE);
703       // TODO: Use something smarter than equality here, e.g., gcd.
704       if (Factor == OpUPair.first)
705         LeftOvers.push_back(OpUPair.second);
706       else if (Factor == SE.getNegativeSCEV(OpUPair.first))
707         LeftOvers.push_back(SE.getNegativeSCEV(OpUPair.second));
708       else
709         return std::make_pair(ConstPart, S);
710     }
711 
712     auto *NewAdd = SE.getAddExpr(LeftOvers, Add->getNoWrapFlags());
713     return std::make_pair(Factor, NewAdd);
714   }
715 
716   auto *Mul = dyn_cast<SCEVMulExpr>(S);
717   if (!Mul)
718     return std::make_pair(ConstPart, S);
719 
720   SmallVector<const SCEV *, 4> LeftOvers;
721   for (auto *Op : Mul->operands())
722     if (isa<SCEVConstant>(Op))
723       ConstPart = cast<SCEVConstant>(SE.getMulExpr(ConstPart, Op));
724     else
725       LeftOvers.push_back(Op);
726 
727   return std::make_pair(ConstPart, SE.getMulExpr(LeftOvers));
728 }
729 
730 const SCEV *polly::tryForwardThroughPHI(const SCEV *Expr, Region &R,
731                                         ScalarEvolution &SE,
732                                         ScopDetection *SD) {
733   if (auto *Unknown = dyn_cast<SCEVUnknown>(Expr)) {
734     Value *V = Unknown->getValue();
735     auto *PHI = dyn_cast<PHINode>(V);
736     if (!PHI)
737       return Expr;
738 
739     Value *Final = nullptr;
740 
741     for (unsigned i = 0; i < PHI->getNumIncomingValues(); i++) {
742       BasicBlock *Incoming = PHI->getIncomingBlock(i);
743       if (SD->isErrorBlock(*Incoming, R) && R.contains(Incoming))
744         continue;
745       if (Final)
746         return Expr;
747       Final = PHI->getIncomingValue(i);
748     }
749 
750     if (Final)
751       return SE.getSCEV(Final);
752   }
753   return Expr;
754 }
755 
756 Value *polly::getUniqueNonErrorValue(PHINode *PHI, Region *R,
757                                      ScopDetection *SD) {
758   Value *V = nullptr;
759   for (unsigned i = 0; i < PHI->getNumIncomingValues(); i++) {
760     BasicBlock *BB = PHI->getIncomingBlock(i);
761     if (!SD->isErrorBlock(*BB, *R)) {
762       if (V)
763         return nullptr;
764       V = PHI->getIncomingValue(i);
765     }
766   }
767 
768   return V;
769 }
770