xref: /llvm-project/polly/lib/Support/SCEVValidator.cpp (revision 09eb4451d2a94995fe985d21ec802b7bb262479c)
1 
2 #include "polly/Support/SCEVValidator.h"
3 #include "polly/ScopInfo.h"
4 #include "llvm/Analysis/RegionInfo.h"
5 #include "llvm/Analysis/ScalarEvolution.h"
6 #include "llvm/Analysis/ScalarEvolutionExpressions.h"
7 #include "llvm/Support/Debug.h"
8 #include <vector>
9 
10 using namespace llvm;
11 using namespace polly;
12 
13 #define DEBUG_TYPE "polly-scev-validator"
14 
namespace SCEVType {
/// @brief The type of a SCEV
///
/// Every SCEV examined by the validator is classified with one of the types
/// below. Their numeric order matters: a larger value describes a less
/// constrained expression, and the subexpressions of a SCEV of type X may only
/// have a type that is smaller than or equal to X. Combining two results
/// therefore keeps the larger of the two types.
enum TYPE {
  /// An integer value.
  INT = 0,

  /// An expression that is constant during the execution of the Scop,
  /// but that may depend on parameters unknown at compile time.
  PARAM = 1,

  /// An expression that may change during the execution of the SCoP.
  IV = 2,

  /// An invalid expression.
  INVALID = 3
};
} // namespace SCEVType
37 
38 /// @brief The result the validator returns for a SCEV expression.
39 class ValidatorResult {
40   /// @brief The type of the expression
41   SCEVType::TYPE Type;
42 
43   /// @brief The set of Parameters in the expression.
44   std::vector<const SCEV *> Parameters;
45 
46 public:
47   /// @brief The copy constructor
48   ValidatorResult(const ValidatorResult &Source) {
49     Type = Source.Type;
50     Parameters = Source.Parameters;
51   }
52 
53   /// @brief Construct a result with a certain type and no parameters.
54   ValidatorResult(SCEVType::TYPE Type) : Type(Type) {
55     assert(Type != SCEVType::PARAM && "Did you forget to pass the parameter");
56   }
57 
58   /// @brief Construct a result with a certain type and a single parameter.
59   ValidatorResult(SCEVType::TYPE Type, const SCEV *Expr) : Type(Type) {
60     Parameters.push_back(Expr);
61   }
62 
63   /// @brief Get the type of the ValidatorResult.
64   SCEVType::TYPE getType() { return Type; }
65 
66   /// @brief Is the analyzed SCEV constant during the execution of the SCoP.
67   bool isConstant() { return Type == SCEVType::INT || Type == SCEVType::PARAM; }
68 
69   /// @brief Is the analyzed SCEV valid.
70   bool isValid() { return Type != SCEVType::INVALID; }
71 
72   /// @brief Is the analyzed SCEV of Type IV.
73   bool isIV() { return Type == SCEVType::IV; }
74 
75   /// @brief Is the analyzed SCEV of Type INT.
76   bool isINT() { return Type == SCEVType::INT; }
77 
78   /// @brief Is the analyzed SCEV of Type PARAM.
79   bool isPARAM() { return Type == SCEVType::PARAM; }
80 
81   /// @brief Get the parameters of this validator result.
82   std::vector<const SCEV *> getParameters() { return Parameters; }
83 
84   /// @brief Add the parameters of Source to this result.
85   void addParamsFrom(const ValidatorResult &Source) {
86     Parameters.insert(Parameters.end(), Source.Parameters.begin(),
87                       Source.Parameters.end());
88   }
89 
90   /// @brief Merge a result.
91   ///
92   /// This means to merge the parameters and to set the Type to the most
93   /// specific Type that matches both.
94   void merge(const ValidatorResult &ToMerge) {
95     Type = std::max(Type, ToMerge.Type);
96     addParamsFrom(ToMerge);
97   }
98 
99   void print(raw_ostream &OS) {
100     switch (Type) {
101     case SCEVType::INT:
102       OS << "SCEVType::INT";
103       break;
104     case SCEVType::PARAM:
105       OS << "SCEVType::PARAM";
106       break;
107     case SCEVType::IV:
108       OS << "SCEVType::IV";
109       break;
110     case SCEVType::INVALID:
111       OS << "SCEVType::INVALID";
112       break;
113     }
114   }
115 };
116 
117 raw_ostream &operator<<(raw_ostream &OS, class ValidatorResult &VR) {
118   VR.print(OS);
119   return OS;
120 }
121 
122 /// Check if a SCEV is valid in a SCoP.
123 struct SCEVValidator
124     : public SCEVVisitor<SCEVValidator, class ValidatorResult> {
125 private:
126   const Region *R;
127   Loop *Scope;
128   ScalarEvolution &SE;
129   const Value *BaseAddress;
130   InvariantLoadsSetTy *ILS;
131 
132 public:
133   SCEVValidator(const Region *R, Loop *Scope, ScalarEvolution &SE,
134                 const Value *BaseAddress, InvariantLoadsSetTy *ILS)
135       : R(R), Scope(Scope), SE(SE), BaseAddress(BaseAddress), ILS(ILS) {}
136 
137   class ValidatorResult visitConstant(const SCEVConstant *Constant) {
138     return ValidatorResult(SCEVType::INT);
139   }
140 
141   class ValidatorResult visitTruncateExpr(const SCEVTruncateExpr *Expr) {
142     ValidatorResult Op = visit(Expr->getOperand());
143 
144     switch (Op.getType()) {
145     case SCEVType::INT:
146     case SCEVType::PARAM:
147       // We currently do not represent a truncate expression as an affine
148       // expression. If it is constant during Scop execution, we treat it as a
149       // parameter.
150       return ValidatorResult(SCEVType::PARAM, Expr);
151     case SCEVType::IV:
152       DEBUG(dbgs() << "INVALID: Truncation of SCEVType::IV expression");
153       return ValidatorResult(SCEVType::INVALID);
154     case SCEVType::INVALID:
155       return Op;
156     }
157 
158     llvm_unreachable("Unknown SCEVType");
159   }
160 
161   class ValidatorResult visitZeroExtendExpr(const SCEVZeroExtendExpr *Expr) {
162     ValidatorResult Op = visit(Expr->getOperand());
163 
164     switch (Op.getType()) {
165     case SCEVType::INT:
166     case SCEVType::PARAM:
167       // We currently do not represent a truncate expression as an affine
168       // expression. If it is constant during Scop execution, we treat it as a
169       // parameter.
170       return ValidatorResult(SCEVType::PARAM, Expr);
171     case SCEVType::IV:
172       DEBUG(dbgs() << "INVALID: ZeroExtend of SCEVType::IV expression");
173       return ValidatorResult(SCEVType::INVALID);
174     case SCEVType::INVALID:
175       return Op;
176     }
177 
178     llvm_unreachable("Unknown SCEVType");
179   }
180 
181   class ValidatorResult visitSignExtendExpr(const SCEVSignExtendExpr *Expr) {
182     // We currently allow only signed SCEV expressions. In the case of a
183     // signed value, a sign extend is a noop.
184     //
185     // TODO: Reconsider this when we add support for unsigned values.
186     return visit(Expr->getOperand());
187   }
188 
189   class ValidatorResult visitAddExpr(const SCEVAddExpr *Expr) {
190     ValidatorResult Return(SCEVType::INT);
191 
192     for (int i = 0, e = Expr->getNumOperands(); i < e; ++i) {
193       ValidatorResult Op = visit(Expr->getOperand(i));
194       Return.merge(Op);
195 
196       // Early exit.
197       if (!Return.isValid())
198         break;
199     }
200 
201     // TODO: Check for NSW and NUW.
202     return Return;
203   }
204 
205   class ValidatorResult visitMulExpr(const SCEVMulExpr *Expr) {
206     ValidatorResult Return(SCEVType::INT);
207 
208     bool HasMultipleParams = false;
209 
210     for (int i = 0, e = Expr->getNumOperands(); i < e; ++i) {
211       ValidatorResult Op = visit(Expr->getOperand(i));
212 
213       if (Op.isINT())
214         continue;
215 
216       if (Op.isPARAM() && Return.isPARAM()) {
217         HasMultipleParams = true;
218         continue;
219       }
220 
221       if ((Op.isIV() || Op.isPARAM()) && !Return.isINT()) {
222         DEBUG(dbgs() << "INVALID: More than one non-int operand in MulExpr\n"
223                      << "\tExpr: " << *Expr << "\n"
224                      << "\tPrevious expression type: " << Return << "\n"
225                      << "\tNext operand (" << Op
226                      << "): " << *Expr->getOperand(i) << "\n");
227 
228         return ValidatorResult(SCEVType::INVALID);
229       }
230 
231       Return.merge(Op);
232     }
233 
234     if (HasMultipleParams && Return.isValid())
235       return ValidatorResult(SCEVType::PARAM, Expr);
236 
237     // TODO: Check for NSW and NUW.
238     return Return;
239   }
240 
241   class ValidatorResult visitUDivExpr(const SCEVUDivExpr *Expr) {
242     ValidatorResult LHS = visit(Expr->getLHS());
243     ValidatorResult RHS = visit(Expr->getRHS());
244 
245     // We currently do not represent an unsigned division as an affine
246     // expression. If the division is constant during Scop execution we treat it
247     // as a parameter, otherwise we bail out.
248     if (LHS.isConstant() && RHS.isConstant())
249       return ValidatorResult(SCEVType::PARAM, Expr);
250 
251     DEBUG(dbgs() << "INVALID: unsigned division of non-constant expressions");
252     return ValidatorResult(SCEVType::INVALID);
253   }
254 
255   class ValidatorResult visitAddRecExpr(const SCEVAddRecExpr *Expr) {
256     if (!Expr->isAffine()) {
257       DEBUG(dbgs() << "INVALID: AddRec is not affine");
258       return ValidatorResult(SCEVType::INVALID);
259     }
260 
261     ValidatorResult Start = visit(Expr->getStart());
262     ValidatorResult Recurrence = visit(Expr->getStepRecurrence(SE));
263 
264     if (!Start.isValid())
265       return Start;
266 
267     if (!Recurrence.isValid())
268       return Recurrence;
269 
270     if (R->contains(Expr->getLoop())) {
271       if (Recurrence.isINT()) {
272         ValidatorResult Result(SCEVType::IV);
273         Result.addParamsFrom(Start);
274         return Result;
275       }
276 
277       DEBUG(dbgs() << "INVALID: AddRec within scop has non-int"
278                       "recurrence part");
279       return ValidatorResult(SCEVType::INVALID);
280     }
281 
282     assert(Start.isConstant() && Recurrence.isConstant() &&
283            "Expected 'Start' and 'Recurrence' to be constant");
284 
285     // Directly generate ValidatorResult for Expr if 'start' is zero.
286     if (Expr->getStart()->isZero())
287       return ValidatorResult(SCEVType::PARAM, Expr);
288 
289     // Translate AddRecExpr from '{start, +, inc}' into 'start + {0, +, inc}'
290     // if 'start' is not zero.
291     const SCEV *ZeroStartExpr = SE.getAddRecExpr(
292         SE.getConstant(Expr->getStart()->getType(), 0),
293         Expr->getStepRecurrence(SE), Expr->getLoop(), Expr->getNoWrapFlags());
294 
295     ValidatorResult ZeroStartResult =
296         ValidatorResult(SCEVType::PARAM, ZeroStartExpr);
297     ZeroStartResult.addParamsFrom(Start);
298 
299     return ZeroStartResult;
300   }
301 
302   class ValidatorResult visitSMaxExpr(const SCEVSMaxExpr *Expr) {
303     ValidatorResult Return(SCEVType::INT);
304 
305     for (int i = 0, e = Expr->getNumOperands(); i < e; ++i) {
306       ValidatorResult Op = visit(Expr->getOperand(i));
307 
308       if (!Op.isValid())
309         return Op;
310 
311       Return.merge(Op);
312     }
313 
314     return Return;
315   }
316 
317   class ValidatorResult visitUMaxExpr(const SCEVUMaxExpr *Expr) {
318     // We do not support unsigned operations. If 'Expr' is constant during Scop
319     // execution we treat this as a parameter, otherwise we bail out.
320     for (int i = 0, e = Expr->getNumOperands(); i < e; ++i) {
321       ValidatorResult Op = visit(Expr->getOperand(i));
322 
323       if (!Op.isConstant()) {
324         DEBUG(dbgs() << "INVALID: UMaxExpr has a non-constant operand");
325         return ValidatorResult(SCEVType::INVALID);
326       }
327     }
328 
329     return ValidatorResult(SCEVType::PARAM, Expr);
330   }
331 
332   ValidatorResult visitGenericInst(Instruction *I, const SCEV *S) {
333     if (R->contains(I)) {
334       DEBUG(dbgs() << "INVALID: UnknownExpr references an instruction "
335                       "within the region\n");
336       return ValidatorResult(SCEVType::INVALID);
337     }
338 
339     return ValidatorResult(SCEVType::PARAM, S);
340   }
341 
342   ValidatorResult visitLoadInstruction(Instruction *I, const SCEV *S) {
343     if (R->contains(I) && ILS) {
344       ILS->insert(cast<LoadInst>(I));
345       return ValidatorResult(SCEVType::PARAM, S);
346     }
347 
348     return visitGenericInst(I, S);
349   }
350 
351   ValidatorResult visitSDivInstruction(Instruction *SDiv, const SCEV *S) {
352     assert(SDiv->getOpcode() == Instruction::SDiv &&
353            "Assumed SDiv instruction!");
354 
355     auto *Divisor = SDiv->getOperand(1);
356     auto *CI = dyn_cast<ConstantInt>(Divisor);
357     if (!CI)
358       return visitGenericInst(SDiv, S);
359 
360     auto *Dividend = SDiv->getOperand(0);
361     auto *DividendSCEV = SE.getSCEV(Dividend);
362     return visit(DividendSCEV);
363   }
364 
365   ValidatorResult visitSRemInstruction(Instruction *SRem, const SCEV *S) {
366     assert(SRem->getOpcode() == Instruction::SRem &&
367            "Assumed SRem instruction!");
368 
369     auto *Divisor = SRem->getOperand(1);
370     auto *CI = dyn_cast<ConstantInt>(Divisor);
371     if (!CI)
372       return visitGenericInst(SRem, S);
373 
374     auto *Dividend = SRem->getOperand(0);
375     auto *DividendSCEV = SE.getSCEV(Dividend);
376     return visit(DividendSCEV);
377   }
378 
379   ValidatorResult visitUnknown(const SCEVUnknown *Expr) {
380     Value *V = Expr->getValue();
381 
382     // TODO: FIXME: IslExprBuilder is not capable of producing valid code
383     //              for arbitrary pointer expressions at the moment. Until
384     //              this is fixed we disallow pointer expressions completely.
385     if (Expr->getType()->isPointerTy()) {
386       DEBUG(dbgs() << "INVALID: UnknownExpr is a pointer type [FIXME]");
387       return ValidatorResult(SCEVType::INVALID);
388     }
389 
390     if (!Expr->getType()->isIntegerTy()) {
391       DEBUG(dbgs() << "INVALID: UnknownExpr is not an integer");
392       return ValidatorResult(SCEVType::INVALID);
393     }
394 
395     if (isa<UndefValue>(V)) {
396       DEBUG(dbgs() << "INVALID: UnknownExpr references an undef value");
397       return ValidatorResult(SCEVType::INVALID);
398     }
399 
400     if (BaseAddress == V) {
401       DEBUG(dbgs() << "INVALID: UnknownExpr references BaseAddress\n");
402       return ValidatorResult(SCEVType::INVALID);
403     }
404 
405     if (Instruction *I = dyn_cast<Instruction>(Expr->getValue())) {
406       switch (I->getOpcode()) {
407       case Instruction::Load:
408         return visitLoadInstruction(I, Expr);
409       case Instruction::SDiv:
410         return visitSDivInstruction(I, Expr);
411       case Instruction::SRem:
412         return visitSRemInstruction(I, Expr);
413       default:
414         return visitGenericInst(I, Expr);
415       }
416     }
417 
418     return ValidatorResult(SCEVType::PARAM, Expr);
419   }
420 };
421 
422 /// @brief Check whether a SCEV refers to an SSA name defined inside a region.
423 ///
424 struct SCEVInRegionDependences
425     : public SCEVVisitor<SCEVInRegionDependences, bool> {
426 public:
427   /// Returns true when the SCEV has SSA names defined in region R. It @p
428   /// AllowLoops is false, loop dependences are checked as well. AddRec SCEVs
429   /// are only allowed within its loop (current loop determined by @p Scope),
430   /// not outside of it unless AddRec's loop is not even in the region.
431   static bool hasDependences(const SCEV *S, const Region *R, Loop *Scope,
432                              bool AllowLoops) {
433     SCEVInRegionDependences Ignore(R, Scope, AllowLoops);
434     return Ignore.visit(S);
435   }
436 
437   SCEVInRegionDependences(const Region *R, Loop *Scope, bool AllowLoops)
438       : R(R), Scope(Scope), AllowLoops(AllowLoops) {}
439 
440   bool visit(const SCEV *Expr) {
441     return SCEVVisitor<SCEVInRegionDependences, bool>::visit(Expr);
442   }
443 
444   bool visitConstant(const SCEVConstant *Constant) { return false; }
445 
446   bool visitTruncateExpr(const SCEVTruncateExpr *Expr) {
447     return visit(Expr->getOperand());
448   }
449 
450   bool visitZeroExtendExpr(const SCEVZeroExtendExpr *Expr) {
451     return visit(Expr->getOperand());
452   }
453 
454   bool visitSignExtendExpr(const SCEVSignExtendExpr *Expr) {
455     return visit(Expr->getOperand());
456   }
457 
458   bool visitAddExpr(const SCEVAddExpr *Expr) {
459     for (int i = 0, e = Expr->getNumOperands(); i < e; ++i)
460       if (visit(Expr->getOperand(i)))
461         return true;
462 
463     return false;
464   }
465 
466   bool visitMulExpr(const SCEVMulExpr *Expr) {
467     for (int i = 0, e = Expr->getNumOperands(); i < e; ++i)
468       if (visit(Expr->getOperand(i)))
469         return true;
470 
471     return false;
472   }
473 
474   bool visitUDivExpr(const SCEVUDivExpr *Expr) {
475     if (visit(Expr->getLHS()))
476       return true;
477 
478     if (visit(Expr->getRHS()))
479       return true;
480 
481     return false;
482   }
483 
484   bool visitAddRecExpr(const SCEVAddRecExpr *Expr) {
485     if (!AllowLoops) {
486       if (!Scope)
487         return true;
488       auto *L = Expr->getLoop();
489       if (R->contains(L) && !L->contains(Scope))
490         return true;
491     }
492 
493     for (size_t i = 0; i < Expr->getNumOperands(); ++i)
494       if (visit(Expr->getOperand(i)))
495         return true;
496 
497     return false;
498   }
499 
500   bool visitSMaxExpr(const SCEVSMaxExpr *Expr) {
501     for (size_t i = 0; i < Expr->getNumOperands(); ++i)
502       if (visit(Expr->getOperand(i)))
503         return true;
504 
505     return false;
506   }
507 
508   bool visitUMaxExpr(const SCEVUMaxExpr *Expr) {
509     for (size_t i = 0; i < Expr->getNumOperands(); ++i)
510       if (visit(Expr->getOperand(i)))
511         return true;
512 
513     return false;
514   }
515 
516   bool visitUnknown(const SCEVUnknown *Expr) {
517     Instruction *Inst = dyn_cast<Instruction>(Expr->getValue());
518 
519     // Return true when Inst is defined inside the region R.
520     if (Inst && R->contains(Inst))
521       return true;
522 
523     return false;
524   }
525 
526 private:
527   const Region *R;
528   Loop *Scope;
529   bool AllowLoops;
530 };
531 
532 namespace polly {
533 /// Find all loops referenced in SCEVAddRecExprs.
534 class SCEVFindLoops {
535   SetVector<const Loop *> &Loops;
536 
537 public:
538   SCEVFindLoops(SetVector<const Loop *> &Loops) : Loops(Loops) {}
539 
540   bool follow(const SCEV *S) {
541     if (const SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(S))
542       Loops.insert(AddRec->getLoop());
543     return true;
544   }
545   bool isDone() { return false; }
546 };
547 
548 void findLoops(const SCEV *Expr, SetVector<const Loop *> &Loops) {
549   SCEVFindLoops FindLoops(Loops);
550   SCEVTraversal<SCEVFindLoops> ST(FindLoops);
551   ST.visitAll(Expr);
552 }
553 
554 /// Find all values referenced in SCEVUnknowns.
555 class SCEVFindValues {
556   SetVector<Value *> &Values;
557 
558 public:
559   SCEVFindValues(SetVector<Value *> &Values) : Values(Values) {}
560 
561   bool follow(const SCEV *S) {
562     if (const SCEVUnknown *Unknown = dyn_cast<SCEVUnknown>(S))
563       Values.insert(Unknown->getValue());
564     return true;
565   }
566   bool isDone() { return false; }
567 };
568 
569 void findValues(const SCEV *Expr, SetVector<Value *> &Values) {
570   SCEVFindValues FindValues(Values);
571   SCEVTraversal<SCEVFindValues> ST(FindValues);
572   ST.visitAll(Expr);
573 }
574 
575 bool hasScalarDepsInsideRegion(const SCEV *Expr, const Region *R,
576                                llvm::Loop *Scope, bool AllowLoops) {
577   return SCEVInRegionDependences::hasDependences(Expr, R, Scope, AllowLoops);
578 }
579 
580 bool isAffineExpr(const Region *R, llvm::Loop *Scope, const SCEV *Expr,
581                   ScalarEvolution &SE, const Value *BaseAddress,
582                   InvariantLoadsSetTy *ILS) {
583   if (isa<SCEVCouldNotCompute>(Expr))
584     return false;
585 
586   SCEVValidator Validator(R, Scope, SE, BaseAddress, ILS);
587   DEBUG({
588     dbgs() << "\n";
589     dbgs() << "Expr: " << *Expr << "\n";
590     dbgs() << "Region: " << R->getNameStr() << "\n";
591     dbgs() << " -> ";
592   });
593 
594   ValidatorResult Result = Validator.visit(Expr);
595 
596   DEBUG({
597     if (Result.isValid())
598       dbgs() << "VALID\n";
599     dbgs() << "\n";
600   });
601 
602   return Result.isValid();
603 }
604 
605 static bool isAffineParamExpr(Value *V, const Region *R, Loop *Scope,
606                               ScalarEvolution &SE,
607                               std::vector<const SCEV *> &Params) {
608   auto *E = SE.getSCEV(V);
609   if (isa<SCEVCouldNotCompute>(E))
610     return false;
611 
612   SCEVValidator Validator(R, Scope, SE, nullptr, nullptr);
613   ValidatorResult Result = Validator.visit(E);
614   if (!Result.isConstant())
615     return false;
616 
617   auto ResultParams = Result.getParameters();
618   Params.insert(Params.end(), ResultParams.begin(), ResultParams.end());
619 
620   return true;
621 }
622 
623 bool isAffineParamConstraint(Value *V, const Region *R, llvm::Loop *Scope,
624                              ScalarEvolution &SE,
625                              std::vector<const SCEV *> &Params, bool OrExpr) {
626   if (auto *ICmp = dyn_cast<ICmpInst>(V)) {
627     return isAffineParamConstraint(ICmp->getOperand(0), R, Scope, SE, Params,
628                                    true) &&
629            isAffineParamConstraint(ICmp->getOperand(1), R, Scope, SE, Params,
630                                    true);
631   } else if (auto *BinOp = dyn_cast<BinaryOperator>(V)) {
632     auto Opcode = BinOp->getOpcode();
633     if (Opcode == Instruction::And || Opcode == Instruction::Or)
634       return isAffineParamConstraint(BinOp->getOperand(0), R, Scope, SE, Params,
635                                      false) &&
636              isAffineParamConstraint(BinOp->getOperand(1), R, Scope, SE, Params,
637                                      false);
638     /* Fall through */
639   }
640 
641   if (!OrExpr)
642     return false;
643 
644   return isAffineParamExpr(V, R, Scope, SE, Params);
645 }
646 
647 std::vector<const SCEV *> getParamsInAffineExpr(const Region *R, Loop *Scope,
648                                                 const SCEV *Expr,
649                                                 ScalarEvolution &SE,
650                                                 const Value *BaseAddress) {
651   if (isa<SCEVCouldNotCompute>(Expr))
652     return std::vector<const SCEV *>();
653 
654   InvariantLoadsSetTy ILS;
655   SCEVValidator Validator(R, Scope, SE, BaseAddress, &ILS);
656   ValidatorResult Result = Validator.visit(Expr);
657   assert(Result.isValid() && "Requested parameters for an invalid SCEV!");
658 
659   return Result.getParameters();
660 }
661 
662 std::pair<const SCEVConstant *, const SCEV *>
663 extractConstantFactor(const SCEV *S, ScalarEvolution &SE) {
664 
665   auto *LeftOver = SE.getConstant(S->getType(), 1);
666   auto *ConstPart = cast<SCEVConstant>(SE.getConstant(S->getType(), 1));
667 
668   if (auto *Constant = dyn_cast<SCEVConstant>(S))
669     return std::make_pair(Constant, LeftOver);
670 
671   auto *AddRec = dyn_cast<SCEVAddRecExpr>(S);
672   if (AddRec) {
673     auto *StartExpr = AddRec->getStart();
674     if (StartExpr->isZero()) {
675       auto StepPair = extractConstantFactor(AddRec->getStepRecurrence(SE), SE);
676       auto *LeftOverAddRec =
677           SE.getAddRecExpr(StartExpr, StepPair.second, AddRec->getLoop(),
678                            AddRec->getNoWrapFlags());
679       return std::make_pair(StepPair.first, LeftOverAddRec);
680     }
681     return std::make_pair(ConstPart, S);
682   }
683 
684   auto *Mul = dyn_cast<SCEVMulExpr>(S);
685   if (!Mul)
686     return std::make_pair(ConstPart, S);
687 
688   for (auto *Op : Mul->operands())
689     if (isa<SCEVConstant>(Op))
690       ConstPart = cast<SCEVConstant>(SE.getMulExpr(ConstPart, Op));
691     else
692       LeftOver = SE.getMulExpr(LeftOver, Op);
693 
694   return std::make_pair(ConstPart, LeftOver);
695 }
696 }
697