xref: /llvm-project/polly/lib/Support/SCEVValidator.cpp (revision 127e0cd21b6e89fe9f052d39aaeea35c64e9b3da)
1 
2 #include "polly/Support/SCEVValidator.h"
3 #include "polly/ScopInfo.h"
4 #include "llvm/Analysis/RegionInfo.h"
5 #include "llvm/Analysis/ScalarEvolution.h"
6 #include "llvm/Analysis/ScalarEvolutionExpressions.h"
7 #include "llvm/Support/Debug.h"
8 
9 using namespace llvm;
10 using namespace polly;
11 
12 #define DEBUG_TYPE "polly-scev-validator"
13 
namespace SCEVType {
/// Classification the validator assigns to every SCEV it inspects.
///
/// The enumerators are deliberately ordered from most to least restrictive:
/// a sub-expression of a SCEV with type X may only have a type that is less
/// than or equal to X, and combining two types keeps the larger one.
enum TYPE {
  INT,   ///< An integer value.
  PARAM, ///< Constant while the SCoP executes, but possibly depending on
         ///< parameters that are unknown at compile time.
  IV,    ///< May change while the SCoP executes.
  INVALID ///< An expression the validator cannot model.
};
} // namespace SCEVType
36 
37 /// The result the validator returns for a SCEV expression.
38 class ValidatorResult {
39   /// The type of the expression
40   SCEVType::TYPE Type;
41 
42   /// The set of Parameters in the expression.
43   ParameterSetTy Parameters;
44 
45 public:
46   /// The copy constructor
47   ValidatorResult(const ValidatorResult &Source) {
48     Type = Source.Type;
49     Parameters = Source.Parameters;
50   }
51 
52   /// Construct a result with a certain type and no parameters.
53   ValidatorResult(SCEVType::TYPE Type) : Type(Type) {
54     assert(Type != SCEVType::PARAM && "Did you forget to pass the parameter");
55   }
56 
57   /// Construct a result with a certain type and a single parameter.
58   ValidatorResult(SCEVType::TYPE Type, const SCEV *Expr) : Type(Type) {
59     Parameters.insert(Expr);
60   }
61 
62   /// Get the type of the ValidatorResult.
63   SCEVType::TYPE getType() { return Type; }
64 
65   /// Is the analyzed SCEV constant during the execution of the SCoP.
66   bool isConstant() { return Type == SCEVType::INT || Type == SCEVType::PARAM; }
67 
68   /// Is the analyzed SCEV valid.
69   bool isValid() { return Type != SCEVType::INVALID; }
70 
71   /// Is the analyzed SCEV of Type IV.
72   bool isIV() { return Type == SCEVType::IV; }
73 
74   /// Is the analyzed SCEV of Type INT.
75   bool isINT() { return Type == SCEVType::INT; }
76 
77   /// Is the analyzed SCEV of Type PARAM.
78   bool isPARAM() { return Type == SCEVType::PARAM; }
79 
80   /// Get the parameters of this validator result.
81   const ParameterSetTy &getParameters() { return Parameters; }
82 
83   /// Add the parameters of Source to this result.
84   void addParamsFrom(const ValidatorResult &Source) {
85     Parameters.insert(Source.Parameters.begin(), Source.Parameters.end());
86   }
87 
88   /// Merge a result.
89   ///
90   /// This means to merge the parameters and to set the Type to the most
91   /// specific Type that matches both.
92   void merge(const ValidatorResult &ToMerge) {
93     Type = std::max(Type, ToMerge.Type);
94     addParamsFrom(ToMerge);
95   }
96 
97   void print(raw_ostream &OS) {
98     switch (Type) {
99     case SCEVType::INT:
100       OS << "SCEVType::INT";
101       break;
102     case SCEVType::PARAM:
103       OS << "SCEVType::PARAM";
104       break;
105     case SCEVType::IV:
106       OS << "SCEVType::IV";
107       break;
108     case SCEVType::INVALID:
109       OS << "SCEVType::INVALID";
110       break;
111     }
112   }
113 };
114 
/// Stream the type name of \p VR (e.g. "SCEVType::INT") to \p OS.
///
/// Used by the validator's DEBUG output; returns \p OS to allow chaining.
raw_ostream &operator<<(raw_ostream &OS, class ValidatorResult &VR) {
  VR.print(OS);
  return OS;
}
119 
120 bool polly::isConstCall(llvm::CallInst *Call) {
121   if (Call->mayReadOrWriteMemory())
122     return false;
123 
124   for (auto &Operand : Call->arg_operands())
125     if (!isa<ConstantInt>(&Operand))
126       return false;
127 
128   return true;
129 }
130 
131 /// Check if a SCEV is valid in a SCoP.
132 struct SCEVValidator
133     : public SCEVVisitor<SCEVValidator, class ValidatorResult> {
134 private:
135   const Region *R;
136   Loop *Scope;
137   ScalarEvolution &SE;
138   InvariantLoadsSetTy *ILS;
139 
140 public:
141   SCEVValidator(const Region *R, Loop *Scope, ScalarEvolution &SE,
142                 InvariantLoadsSetTy *ILS)
143       : R(R), Scope(Scope), SE(SE), ILS(ILS) {}
144 
145   class ValidatorResult visitConstant(const SCEVConstant *Constant) {
146     return ValidatorResult(SCEVType::INT);
147   }
148 
149   class ValidatorResult visitZeroExtendOrTruncateExpr(const SCEV *Expr,
150                                                       const SCEV *Operand) {
151     ValidatorResult Op = visit(Operand);
152     auto Type = Op.getType();
153 
154     // If unsigned operations are allowed return the operand, otherwise
155     // check if we can model the expression without unsigned assumptions.
156     if (PollyAllowUnsignedOperations || Type == SCEVType::INVALID)
157       return Op;
158 
159     if (Type == SCEVType::IV)
160       return ValidatorResult(SCEVType::INVALID);
161     return ValidatorResult(SCEVType::PARAM, Expr);
162   }
163 
164   class ValidatorResult visitTruncateExpr(const SCEVTruncateExpr *Expr) {
165     return visitZeroExtendOrTruncateExpr(Expr, Expr->getOperand());
166   }
167 
168   class ValidatorResult visitZeroExtendExpr(const SCEVZeroExtendExpr *Expr) {
169     return visitZeroExtendOrTruncateExpr(Expr, Expr->getOperand());
170   }
171 
172   class ValidatorResult visitSignExtendExpr(const SCEVSignExtendExpr *Expr) {
173     return visit(Expr->getOperand());
174   }
175 
176   class ValidatorResult visitAddExpr(const SCEVAddExpr *Expr) {
177     ValidatorResult Return(SCEVType::INT);
178 
179     for (int i = 0, e = Expr->getNumOperands(); i < e; ++i) {
180       ValidatorResult Op = visit(Expr->getOperand(i));
181       Return.merge(Op);
182 
183       // Early exit.
184       if (!Return.isValid())
185         break;
186     }
187 
188     return Return;
189   }
190 
191   class ValidatorResult visitMulExpr(const SCEVMulExpr *Expr) {
192     ValidatorResult Return(SCEVType::INT);
193 
194     bool HasMultipleParams = false;
195 
196     for (int i = 0, e = Expr->getNumOperands(); i < e; ++i) {
197       ValidatorResult Op = visit(Expr->getOperand(i));
198 
199       if (Op.isINT())
200         continue;
201 
202       if (Op.isPARAM() && Return.isPARAM()) {
203         HasMultipleParams = true;
204         continue;
205       }
206 
207       if ((Op.isIV() || Op.isPARAM()) && !Return.isINT()) {
208         DEBUG(dbgs() << "INVALID: More than one non-int operand in MulExpr\n"
209                      << "\tExpr: " << *Expr << "\n"
210                      << "\tPrevious expression type: " << Return << "\n"
211                      << "\tNext operand (" << Op
212                      << "): " << *Expr->getOperand(i) << "\n");
213 
214         return ValidatorResult(SCEVType::INVALID);
215       }
216 
217       Return.merge(Op);
218     }
219 
220     if (HasMultipleParams && Return.isValid())
221       return ValidatorResult(SCEVType::PARAM, Expr);
222 
223     return Return;
224   }
225 
226   class ValidatorResult visitAddRecExpr(const SCEVAddRecExpr *Expr) {
227     if (!Expr->isAffine()) {
228       DEBUG(dbgs() << "INVALID: AddRec is not affine");
229       return ValidatorResult(SCEVType::INVALID);
230     }
231 
232     ValidatorResult Start = visit(Expr->getStart());
233     ValidatorResult Recurrence = visit(Expr->getStepRecurrence(SE));
234 
235     if (!Start.isValid())
236       return Start;
237 
238     if (!Recurrence.isValid())
239       return Recurrence;
240 
241     auto *L = Expr->getLoop();
242     if (R->contains(L) && (!Scope || !L->contains(Scope))) {
243       DEBUG(dbgs() << "INVALID: Loop of AddRec expression boxed in an a "
244                       "non-affine subregion or has a non-synthesizable exit "
245                       "value.");
246       return ValidatorResult(SCEVType::INVALID);
247     }
248 
249     if (R->contains(L)) {
250       if (Recurrence.isINT()) {
251         ValidatorResult Result(SCEVType::IV);
252         Result.addParamsFrom(Start);
253         return Result;
254       }
255 
256       DEBUG(dbgs() << "INVALID: AddRec within scop has non-int"
257                       "recurrence part");
258       return ValidatorResult(SCEVType::INVALID);
259     }
260 
261     assert(Recurrence.isConstant() && "Expected 'Recurrence' to be constant");
262 
263     // Directly generate ValidatorResult for Expr if 'start' is zero.
264     if (Expr->getStart()->isZero())
265       return ValidatorResult(SCEVType::PARAM, Expr);
266 
267     // Translate AddRecExpr from '{start, +, inc}' into 'start + {0, +, inc}'
268     // if 'start' is not zero.
269     const SCEV *ZeroStartExpr = SE.getAddRecExpr(
270         SE.getConstant(Expr->getStart()->getType(), 0),
271         Expr->getStepRecurrence(SE), Expr->getLoop(), Expr->getNoWrapFlags());
272 
273     ValidatorResult ZeroStartResult =
274         ValidatorResult(SCEVType::PARAM, ZeroStartExpr);
275     ZeroStartResult.addParamsFrom(Start);
276 
277     return ZeroStartResult;
278   }
279 
280   class ValidatorResult visitSMaxExpr(const SCEVSMaxExpr *Expr) {
281     ValidatorResult Return(SCEVType::INT);
282 
283     for (int i = 0, e = Expr->getNumOperands(); i < e; ++i) {
284       ValidatorResult Op = visit(Expr->getOperand(i));
285 
286       if (!Op.isValid())
287         return Op;
288 
289       Return.merge(Op);
290     }
291 
292     return Return;
293   }
294 
295   class ValidatorResult visitUMaxExpr(const SCEVUMaxExpr *Expr) {
296     // We do not support unsigned max operations. If 'Expr' is constant during
297     // Scop execution we treat this as a parameter, otherwise we bail out.
298     for (int i = 0, e = Expr->getNumOperands(); i < e; ++i) {
299       ValidatorResult Op = visit(Expr->getOperand(i));
300 
301       if (!Op.isConstant()) {
302         DEBUG(dbgs() << "INVALID: UMaxExpr has a non-constant operand");
303         return ValidatorResult(SCEVType::INVALID);
304       }
305     }
306 
307     return ValidatorResult(SCEVType::PARAM, Expr);
308   }
309 
310   ValidatorResult visitGenericInst(Instruction *I, const SCEV *S) {
311     if (R->contains(I)) {
312       DEBUG(dbgs() << "INVALID: UnknownExpr references an instruction "
313                       "within the region\n");
314       return ValidatorResult(SCEVType::INVALID);
315     }
316 
317     return ValidatorResult(SCEVType::PARAM, S);
318   }
319 
320   ValidatorResult visitCallInstruction(Instruction *I, const SCEV *S) {
321     assert(I->getOpcode() == Instruction::Call && "Call instruction expected");
322 
323     if (R->contains(I)) {
324       auto Call = cast<CallInst>(I);
325 
326       if (!isConstCall(Call))
327         return ValidatorResult(SCEVType::INVALID, S);
328     }
329     return ValidatorResult(SCEVType::PARAM, S);
330   }
331 
332   ValidatorResult visitLoadInstruction(Instruction *I, const SCEV *S) {
333     if (R->contains(I) && ILS) {
334       ILS->insert(cast<LoadInst>(I));
335       return ValidatorResult(SCEVType::PARAM, S);
336     }
337 
338     return visitGenericInst(I, S);
339   }
340 
341   ValidatorResult visitDivision(const SCEV *Dividend, const SCEV *Divisor,
342                                 const SCEV *DivExpr,
343                                 Instruction *SDiv = nullptr) {
344 
345     // First check if we might be able to model the division, thus if the
346     // divisor is constant. If so, check the dividend, otherwise check if
347     // the whole division can be seen as a parameter.
348     if (isa<SCEVConstant>(Divisor) && !Divisor->isZero())
349       return visit(Dividend);
350 
351     // For signed divisions use the SDiv instruction to check for a parameter
352     // division, for unsigned divisions check the operands.
353     if (SDiv)
354       return visitGenericInst(SDiv, DivExpr);
355 
356     ValidatorResult LHS = visit(Dividend);
357     ValidatorResult RHS = visit(Divisor);
358     if (LHS.isConstant() && RHS.isConstant())
359       return ValidatorResult(SCEVType::PARAM, DivExpr);
360 
361     DEBUG(dbgs() << "INVALID: unsigned division of non-constant expressions");
362     return ValidatorResult(SCEVType::INVALID);
363   }
364 
365   ValidatorResult visitUDivExpr(const SCEVUDivExpr *Expr) {
366     if (!PollyAllowUnsignedOperations)
367       return ValidatorResult(SCEVType::INVALID);
368 
369     auto *Dividend = Expr->getLHS();
370     auto *Divisor = Expr->getRHS();
371     return visitDivision(Dividend, Divisor, Expr);
372   }
373 
374   ValidatorResult visitSDivInstruction(Instruction *SDiv, const SCEV *Expr) {
375     assert(SDiv->getOpcode() == Instruction::SDiv &&
376            "Assumed SDiv instruction!");
377 
378     auto *Dividend = SE.getSCEV(SDiv->getOperand(0));
379     auto *Divisor = SE.getSCEV(SDiv->getOperand(1));
380     return visitDivision(Dividend, Divisor, Expr, SDiv);
381   }
382 
383   ValidatorResult visitSRemInstruction(Instruction *SRem, const SCEV *S) {
384     assert(SRem->getOpcode() == Instruction::SRem &&
385            "Assumed SRem instruction!");
386 
387     auto *Divisor = SRem->getOperand(1);
388     auto *CI = dyn_cast<ConstantInt>(Divisor);
389     if (!CI || CI->isZeroValue())
390       return visitGenericInst(SRem, S);
391 
392     auto *Dividend = SRem->getOperand(0);
393     auto *DividendSCEV = SE.getSCEV(Dividend);
394     return visit(DividendSCEV);
395   }
396 
397   ValidatorResult visitUnknown(const SCEVUnknown *Expr) {
398     Value *V = Expr->getValue();
399 
400     if (!Expr->getType()->isIntegerTy() && !Expr->getType()->isPointerTy()) {
401       DEBUG(dbgs() << "INVALID: UnknownExpr is not an integer or pointer");
402       return ValidatorResult(SCEVType::INVALID);
403     }
404 
405     if (isa<UndefValue>(V)) {
406       DEBUG(dbgs() << "INVALID: UnknownExpr references an undef value");
407       return ValidatorResult(SCEVType::INVALID);
408     }
409 
410     if (Instruction *I = dyn_cast<Instruction>(Expr->getValue())) {
411       switch (I->getOpcode()) {
412       case Instruction::IntToPtr:
413         return visit(SE.getSCEVAtScope(I->getOperand(0), Scope));
414       case Instruction::PtrToInt:
415         return visit(SE.getSCEVAtScope(I->getOperand(0), Scope));
416       case Instruction::Load:
417         return visitLoadInstruction(I, Expr);
418       case Instruction::SDiv:
419         return visitSDivInstruction(I, Expr);
420       case Instruction::SRem:
421         return visitSRemInstruction(I, Expr);
422       case Instruction::Call:
423         return visitCallInstruction(I, Expr);
424       default:
425         return visitGenericInst(I, Expr);
426       }
427     }
428 
429     return ValidatorResult(SCEVType::PARAM, Expr);
430   }
431 };
432 
433 class SCEVHasIVParams {
434   bool HasIVParams = false;
435 
436 public:
437   SCEVHasIVParams() {}
438 
439   bool follow(const SCEV *S) {
440     const SCEVUnknown *Unknown = dyn_cast<SCEVUnknown>(S);
441     if (!Unknown)
442       return true;
443 
444     CallInst *Call = dyn_cast<CallInst>(Unknown->getValue());
445 
446     if (!Call)
447       return true;
448 
449     if (isConstCall(Call)) {
450       HasIVParams = true;
451       return false;
452     }
453 
454     return true;
455   }
456 
457   bool isDone() { return HasIVParams; }
458   bool hasIVParams() { return HasIVParams; }
459 };
460 
461 /// Check whether a SCEV refers to an SSA name defined inside a region.
462 class SCEVInRegionDependences {
463   const Region *R;
464   Loop *Scope;
465   bool AllowLoops;
466   bool HasInRegionDeps = false;
467 
468 public:
469   SCEVInRegionDependences(const Region *R, Loop *Scope, bool AllowLoops)
470       : R(R), Scope(Scope), AllowLoops(AllowLoops) {}
471 
472   bool follow(const SCEV *S) {
473     if (auto Unknown = dyn_cast<SCEVUnknown>(S)) {
474       Instruction *Inst = dyn_cast<Instruction>(Unknown->getValue());
475 
476       CallInst *Call = dyn_cast<CallInst>(Unknown->getValue());
477 
478       if (Call && isConstCall(Call))
479         return false;
480 
481       // Return true when Inst is defined inside the region R.
482       if (!Inst || !R->contains(Inst))
483         return true;
484 
485       HasInRegionDeps = true;
486       return false;
487     }
488 
489     if (auto AddRec = dyn_cast<SCEVAddRecExpr>(S)) {
490       if (AllowLoops)
491         return true;
492 
493       if (!Scope) {
494         HasInRegionDeps = true;
495         return false;
496       }
497       auto *L = AddRec->getLoop();
498       if (R->contains(L) && !L->contains(Scope)) {
499         HasInRegionDeps = true;
500         return false;
501       }
502     }
503 
504     return true;
505   }
506   bool isDone() { return false; }
507   bool hasDependences() { return HasInRegionDeps; }
508 };
509 
510 namespace polly {
511 /// Find all loops referenced in SCEVAddRecExprs.
512 class SCEVFindLoops {
513   SetVector<const Loop *> &Loops;
514 
515 public:
516   SCEVFindLoops(SetVector<const Loop *> &Loops) : Loops(Loops) {}
517 
518   bool follow(const SCEV *S) {
519     if (const SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(S))
520       Loops.insert(AddRec->getLoop());
521     return true;
522   }
523   bool isDone() { return false; }
524 };
525 
526 void findLoops(const SCEV *Expr, SetVector<const Loop *> &Loops) {
527   SCEVFindLoops FindLoops(Loops);
528   SCEVTraversal<SCEVFindLoops> ST(FindLoops);
529   ST.visitAll(Expr);
530 }
531 
532 /// Find all values referenced in SCEVUnknowns.
533 class SCEVFindValues {
534   ScalarEvolution &SE;
535   SetVector<Value *> &Values;
536 
537 public:
538   SCEVFindValues(ScalarEvolution &SE, SetVector<Value *> &Values)
539       : SE(SE), Values(Values) {}
540 
541   bool follow(const SCEV *S) {
542     const SCEVUnknown *Unknown = dyn_cast<SCEVUnknown>(S);
543     if (!Unknown)
544       return true;
545 
546     Values.insert(Unknown->getValue());
547     Instruction *Inst = dyn_cast<Instruction>(Unknown->getValue());
548     if (!Inst || (Inst->getOpcode() != Instruction::SRem &&
549                   Inst->getOpcode() != Instruction::SDiv))
550       return false;
551 
552     auto *Dividend = SE.getSCEV(Inst->getOperand(1));
553     if (!isa<SCEVConstant>(Dividend))
554       return false;
555 
556     auto *Divisor = SE.getSCEV(Inst->getOperand(0));
557     SCEVFindValues FindValues(SE, Values);
558     SCEVTraversal<SCEVFindValues> ST(FindValues);
559     ST.visitAll(Dividend);
560     ST.visitAll(Divisor);
561 
562     return false;
563   }
564   bool isDone() { return false; }
565 };
566 
567 void findValues(const SCEV *Expr, ScalarEvolution &SE,
568                 SetVector<Value *> &Values) {
569   SCEVFindValues FindValues(SE, Values);
570   SCEVTraversal<SCEVFindValues> ST(FindValues);
571   ST.visitAll(Expr);
572 }
573 
574 bool hasIVParams(const SCEV *Expr) {
575   SCEVHasIVParams HasIVParams;
576   SCEVTraversal<SCEVHasIVParams> ST(HasIVParams);
577   ST.visitAll(Expr);
578   return HasIVParams.hasIVParams();
579 }
580 
581 bool hasScalarDepsInsideRegion(const SCEV *Expr, const Region *R,
582                                llvm::Loop *Scope, bool AllowLoops) {
583   SCEVInRegionDependences InRegionDeps(R, Scope, AllowLoops);
584   SCEVTraversal<SCEVInRegionDependences> ST(InRegionDeps);
585   ST.visitAll(Expr);
586   return InRegionDeps.hasDependences();
587 }
588 
589 bool isAffineExpr(const Region *R, llvm::Loop *Scope, const SCEV *Expr,
590                   ScalarEvolution &SE, InvariantLoadsSetTy *ILS) {
591   if (isa<SCEVCouldNotCompute>(Expr))
592     return false;
593 
594   SCEVValidator Validator(R, Scope, SE, ILS);
595   DEBUG({
596     dbgs() << "\n";
597     dbgs() << "Expr: " << *Expr << "\n";
598     dbgs() << "Region: " << R->getNameStr() << "\n";
599     dbgs() << " -> ";
600   });
601 
602   ValidatorResult Result = Validator.visit(Expr);
603 
604   DEBUG({
605     if (Result.isValid())
606       dbgs() << "VALID\n";
607     dbgs() << "\n";
608   });
609 
610   return Result.isValid();
611 }
612 
613 static bool isAffineExpr(Value *V, const Region *R, Loop *Scope,
614                          ScalarEvolution &SE, ParameterSetTy &Params) {
615   auto *E = SE.getSCEV(V);
616   if (isa<SCEVCouldNotCompute>(E))
617     return false;
618 
619   SCEVValidator Validator(R, Scope, SE, nullptr);
620   ValidatorResult Result = Validator.visit(E);
621   if (!Result.isValid())
622     return false;
623 
624   auto ResultParams = Result.getParameters();
625   Params.insert(ResultParams.begin(), ResultParams.end());
626 
627   return true;
628 }
629 
630 bool isAffineConstraint(Value *V, const Region *R, llvm::Loop *Scope,
631                         ScalarEvolution &SE, ParameterSetTy &Params,
632                         bool OrExpr) {
633   if (auto *ICmp = dyn_cast<ICmpInst>(V)) {
634     return isAffineConstraint(ICmp->getOperand(0), R, Scope, SE, Params,
635                               true) &&
636            isAffineConstraint(ICmp->getOperand(1), R, Scope, SE, Params, true);
637   } else if (auto *BinOp = dyn_cast<BinaryOperator>(V)) {
638     auto Opcode = BinOp->getOpcode();
639     if (Opcode == Instruction::And || Opcode == Instruction::Or)
640       return isAffineConstraint(BinOp->getOperand(0), R, Scope, SE, Params,
641                                 false) &&
642              isAffineConstraint(BinOp->getOperand(1), R, Scope, SE, Params,
643                                 false);
644     /* Fall through */
645   }
646 
647   if (!OrExpr)
648     return false;
649 
650   return isAffineExpr(V, R, Scope, SE, Params);
651 }
652 
653 ParameterSetTy getParamsInAffineExpr(const Region *R, Loop *Scope,
654                                      const SCEV *Expr, ScalarEvolution &SE) {
655   if (isa<SCEVCouldNotCompute>(Expr))
656     return ParameterSetTy();
657 
658   InvariantLoadsSetTy ILS;
659   SCEVValidator Validator(R, Scope, SE, &ILS);
660   ValidatorResult Result = Validator.visit(Expr);
661   assert(Result.isValid() && "Requested parameters for an invalid SCEV!");
662 
663   return Result.getParameters();
664 }
665 
/// Split \p S into a constant factor and the remaining expression.
///
/// Returns (Factor, Rest), intended so that Factor * Rest equals S:
/// - a constant C yields (C, 1);
/// - an AddRec {0,+,Step} yields the factor of Step and {0,+,Rest of Step};
/// - an Add whose summands all share the same constant factor (up to sign)
///   yields that factor (made non-negative when it is known negative) and
///   the sum of the adjusted rests;
/// - a Mul yields the product of its constant operands and the product of
///   the others.
/// In every other case the result is (1, S).
std::pair<const SCEVConstant *, const SCEV *>
extractConstantFactor(const SCEV *S, ScalarEvolution &SE) {
  // The neutral factor 1, returned whenever nothing can be extracted.
  auto *ConstPart = cast<SCEVConstant>(SE.getConstant(S->getType(), 1));

  if (auto *Constant = dyn_cast<SCEVConstant>(S))
    return std::make_pair(Constant, SE.getConstant(S->getType(), 1));

  auto *AddRec = dyn_cast<SCEVAddRecExpr>(S);
  if (AddRec) {
    auto *StartExpr = AddRec->getStart();
    // Only AddRecs starting at zero are split: {0,+,C*X} -> C * {0,+,X}.
    if (StartExpr->isZero()) {
      auto StepPair = extractConstantFactor(AddRec->getStepRecurrence(SE), SE);
      // NOTE(review): the original no-wrap flags are reused for the AddRec
      // with the reduced step; confirm they remain valid.
      auto *LeftOverAddRec =
          SE.getAddRecExpr(StartExpr, StepPair.second, AddRec->getLoop(),
                           AddRec->getNoWrapFlags());
      return std::make_pair(StepPair.first, LeftOverAddRec);
    }
    return std::make_pair(ConstPart, S);
  }

  if (auto *Add = dyn_cast<SCEVAddExpr>(S)) {
    SmallVector<const SCEV *, 4> LeftOvers;
    // The first summand's factor is the candidate; a known-negative factor
    // is negated and its rest negated to compensate.
    auto Op0Pair = extractConstantFactor(Add->getOperand(0), SE);
    auto *Factor = Op0Pair.first;
    if (SE.isKnownNegative(Factor)) {
      Factor = cast<SCEVConstant>(SE.getNegativeSCEV(Factor));
      LeftOvers.push_back(SE.getNegativeSCEV(Op0Pair.second));
    } else {
      LeftOvers.push_back(Op0Pair.second);
    }

    // Every other summand must have the same factor or its negation;
    // otherwise nothing is extracted.
    for (unsigned u = 1, e = Add->getNumOperands(); u < e; u++) {
      auto OpUPair = extractConstantFactor(Add->getOperand(u), SE);
      // TODO: Use something smarter than equality here, e.g., gcd.
      if (Factor == OpUPair.first)
        LeftOvers.push_back(OpUPair.second);
      else if (Factor == SE.getNegativeSCEV(OpUPair.first))
        LeftOvers.push_back(SE.getNegativeSCEV(OpUPair.second));
      else
        return std::make_pair(ConstPart, S);
    }

    // NOTE(review): the original Add's no-wrap flags are reused for the sum
    // of the rests; confirm they remain valid.
    auto *NewAdd = SE.getAddExpr(LeftOvers, Add->getNoWrapFlags());
    return std::make_pair(Factor, NewAdd);
  }

  auto *Mul = dyn_cast<SCEVMulExpr>(S);
  if (!Mul)
    return std::make_pair(ConstPart, S);

  // Multiply all constant operands into the factor; the rest stay behind.
  SmallVector<const SCEV *, 4> LeftOvers;
  for (auto *Op : Mul->operands())
    if (isa<SCEVConstant>(Op))
      ConstPart = cast<SCEVConstant>(SE.getMulExpr(ConstPart, Op));
    else
      LeftOvers.push_back(Op);

  return std::make_pair(ConstPart, SE.getMulExpr(LeftOvers));
}
725 } // namespace polly
726