1 2 #include "polly/Support/SCEVValidator.h" 3 #include "polly/ScopInfo.h" 4 #include "llvm/Analysis/RegionInfo.h" 5 #include "llvm/Analysis/ScalarEvolution.h" 6 #include "llvm/Analysis/ScalarEvolutionExpressions.h" 7 #include "llvm/Support/Debug.h" 8 9 using namespace llvm; 10 using namespace polly; 11 12 #define DEBUG_TYPE "polly-scev-validator" 13 14 namespace SCEVType { 15 /// The type of a SCEV 16 /// 17 /// To check for the validity of a SCEV we assign to each SCEV a type. The 18 /// possible types are INT, PARAM, IV and INVALID. The order of the types is 19 /// important. The subexpressions of SCEV with a type X can only have a type 20 /// that is smaller or equal than X. 21 enum TYPE { 22 // An integer value. 23 INT, 24 25 // An expression that is constant during the execution of the Scop, 26 // but that may depend on parameters unknown at compile time. 27 PARAM, 28 29 // An expression that may change during the execution of the SCoP. 30 IV, 31 32 // An invalid expression. 33 INVALID 34 }; 35 } // namespace SCEVType 36 37 /// The result the validator returns for a SCEV expression. 38 class ValidatorResult { 39 /// The type of the expression 40 SCEVType::TYPE Type; 41 42 /// The set of Parameters in the expression. 43 ParameterSetTy Parameters; 44 45 public: 46 /// The copy constructor 47 ValidatorResult(const ValidatorResult &Source) { 48 Type = Source.Type; 49 Parameters = Source.Parameters; 50 } 51 52 /// Construct a result with a certain type and no parameters. 53 ValidatorResult(SCEVType::TYPE Type) : Type(Type) { 54 assert(Type != SCEVType::PARAM && "Did you forget to pass the parameter"); 55 } 56 57 /// Construct a result with a certain type and a single parameter. 58 ValidatorResult(SCEVType::TYPE Type, const SCEV *Expr) : Type(Type) { 59 Parameters.insert(Expr); 60 } 61 62 /// Get the type of the ValidatorResult. 63 SCEVType::TYPE getType() { return Type; } 64 65 /// Is the analyzed SCEV constant during the execution of the SCoP. 66 bool isConstant() { return Type == SCEVType::INT || Type == SCEVType::PARAM; } 67 68 /// Is the analyzed SCEV valid. 69 bool isValid() { return Type != SCEVType::INVALID; } 70 71 /// Is the analyzed SCEV of Type IV. 72 bool isIV() { return Type == SCEVType::IV; } 73 74 /// Is the analyzed SCEV of Type INT. 75 bool isINT() { return Type == SCEVType::INT; } 76 77 /// Is the analyzed SCEV of Type PARAM. 78 bool isPARAM() { return Type == SCEVType::PARAM; } 79 80 /// Get the parameters of this validator result. 81 const ParameterSetTy &getParameters() { return Parameters; } 82 83 /// Add the parameters of Source to this result. 84 void addParamsFrom(const ValidatorResult &Source) { 85 Parameters.insert(Source.Parameters.begin(), Source.Parameters.end()); 86 } 87 88 /// Merge a result. 89 /// 90 /// This means to merge the parameters and to set the Type to the most 91 /// specific Type that matches both. 92 void merge(const ValidatorResult &ToMerge) { 93 Type = std::max(Type, ToMerge.Type); 94 addParamsFrom(ToMerge); 95 } 96 97 void print(raw_ostream &OS) { 98 switch (Type) { 99 case SCEVType::INT: 100 OS << "SCEVType::INT"; 101 break; 102 case SCEVType::PARAM: 103 OS << "SCEVType::PARAM"; 104 break; 105 case SCEVType::IV: 106 OS << "SCEVType::IV"; 107 break; 108 case SCEVType::INVALID: 109 OS << "SCEVType::INVALID"; 110 break; 111 } 112 } 113 }; 114 115 raw_ostream &operator<<(raw_ostream &OS, class ValidatorResult &VR) { 116 VR.print(OS); 117 return OS; 118 } 119 120 bool polly::isConstCall(llvm::CallInst *Call) { 121 if (Call->mayReadOrWriteMemory()) 122 return false; 123 124 for (auto &Operand : Call->arg_operands()) 125 if (!isa<ConstantInt>(&Operand)) 126 return false; 127 128 return true; 129 } 130 131 /// Check if a SCEV is valid in a SCoP. 132 struct SCEVValidator 133 : public SCEVVisitor<SCEVValidator, class ValidatorResult> { 134 private: 135 const Region *R; 136 Loop *Scope; 137 ScalarEvolution &SE; 138 InvariantLoadsSetTy *ILS; 139 140 public: 141 SCEVValidator(const Region *R, Loop *Scope, ScalarEvolution &SE, 142 InvariantLoadsSetTy *ILS) 143 : R(R), Scope(Scope), SE(SE), ILS(ILS) {} 144 145 class ValidatorResult visitConstant(const SCEVConstant *Constant) { 146 return ValidatorResult(SCEVType::INT); 147 } 148 149 class ValidatorResult visitZeroExtendOrTruncateExpr(const SCEV *Expr, 150 const SCEV *Operand) { 151 ValidatorResult Op = visit(Operand); 152 auto Type = Op.getType(); 153 154 // If unsigned operations are allowed return the operand, otherwise 155 // check if we can model the expression without unsigned assumptions. 156 if (PollyAllowUnsignedOperations || Type == SCEVType::INVALID) 157 return Op; 158 159 if (Type == SCEVType::IV) 160 return ValidatorResult(SCEVType::INVALID); 161 return ValidatorResult(SCEVType::PARAM, Expr); 162 } 163 164 class ValidatorResult visitTruncateExpr(const SCEVTruncateExpr *Expr) { 165 return visitZeroExtendOrTruncateExpr(Expr, Expr->getOperand()); 166 } 167 168 class ValidatorResult visitZeroExtendExpr(const SCEVZeroExtendExpr *Expr) { 169 return visitZeroExtendOrTruncateExpr(Expr, Expr->getOperand()); 170 } 171 172 class ValidatorResult visitSignExtendExpr(const SCEVSignExtendExpr *Expr) { 173 return visit(Expr->getOperand()); 174 } 175 176 class ValidatorResult visitAddExpr(const SCEVAddExpr *Expr) { 177 ValidatorResult Return(SCEVType::INT); 178 179 for (int i = 0, e = Expr->getNumOperands(); i < e; ++i) { 180 ValidatorResult Op = visit(Expr->getOperand(i)); 181 Return.merge(Op); 182 183 // Early exit. 184 if (!Return.isValid()) 185 break; 186 } 187 188 return Return; 189 } 190 191 class ValidatorResult visitMulExpr(const SCEVMulExpr *Expr) { 192 ValidatorResult Return(SCEVType::INT); 193 194 bool HasMultipleParams = false; 195 196 for (int i = 0, e = Expr->getNumOperands(); i < e; ++i) { 197 ValidatorResult Op = visit(Expr->getOperand(i)); 198 199 if (Op.isINT()) 200 continue; 201 202 if (Op.isPARAM() && Return.isPARAM()) { 203 HasMultipleParams = true; 204 continue; 205 } 206 207 if ((Op.isIV() || Op.isPARAM()) && !Return.isINT()) { 208 DEBUG(dbgs() << "INVALID: More than one non-int operand in MulExpr\n" 209 << "\tExpr: " << *Expr << "\n" 210 << "\tPrevious expression type: " << Return << "\n" 211 << "\tNext operand (" << Op 212 << "): " << *Expr->getOperand(i) << "\n"); 213 214 return ValidatorResult(SCEVType::INVALID); 215 } 216 217 Return.merge(Op); 218 } 219 220 if (HasMultipleParams && Return.isValid()) 221 return ValidatorResult(SCEVType::PARAM, Expr); 222 223 return Return; 224 } 225 226 class ValidatorResult visitAddRecExpr(const SCEVAddRecExpr *Expr) { 227 if (!Expr->isAffine()) { 228 DEBUG(dbgs() << "INVALID: AddRec is not affine"); 229 return ValidatorResult(SCEVType::INVALID); 230 } 231 232 ValidatorResult Start = visit(Expr->getStart()); 233 ValidatorResult Recurrence = visit(Expr->getStepRecurrence(SE)); 234 235 if (!Start.isValid()) 236 return Start; 237 238 if (!Recurrence.isValid()) 239 return Recurrence; 240 241 auto *L = Expr->getLoop(); 242 if (R->contains(L) && (!Scope || !L->contains(Scope))) { 243 DEBUG(dbgs() << "INVALID: Loop of AddRec expression boxed in an a " 244 "non-affine subregion or has a non-synthesizable exit " 245 "value."); 246 return ValidatorResult(SCEVType::INVALID); 247 } 248 249 if (R->contains(L)) { 250 if (Recurrence.isINT()) { 251 ValidatorResult Result(SCEVType::IV); 252 Result.addParamsFrom(Start); 253 return Result; 254 } 255 256 DEBUG(dbgs() << "INVALID: AddRec within scop has non-int" 257 "recurrence part"); 258 return ValidatorResult(SCEVType::INVALID); 259 } 260 261 assert(Recurrence.isConstant() && "Expected 'Recurrence' to be constant"); 262 263 // Directly generate ValidatorResult for Expr if 'start' is zero. 264 if (Expr->getStart()->isZero()) 265 return ValidatorResult(SCEVType::PARAM, Expr); 266 267 // Translate AddRecExpr from '{start, +, inc}' into 'start + {0, +, inc}' 268 // if 'start' is not zero. 269 const SCEV *ZeroStartExpr = SE.getAddRecExpr( 270 SE.getConstant(Expr->getStart()->getType(), 0), 271 Expr->getStepRecurrence(SE), Expr->getLoop(), Expr->getNoWrapFlags()); 272 273 ValidatorResult ZeroStartResult = 274 ValidatorResult(SCEVType::PARAM, ZeroStartExpr); 275 ZeroStartResult.addParamsFrom(Start); 276 277 return ZeroStartResult; 278 } 279 280 class ValidatorResult visitSMaxExpr(const SCEVSMaxExpr *Expr) { 281 ValidatorResult Return(SCEVType::INT); 282 283 for (int i = 0, e = Expr->getNumOperands(); i < e; ++i) { 284 ValidatorResult Op = visit(Expr->getOperand(i)); 285 286 if (!Op.isValid()) 287 return Op; 288 289 Return.merge(Op); 290 } 291 292 return Return; 293 } 294 295 class ValidatorResult visitUMaxExpr(const SCEVUMaxExpr *Expr) { 296 // We do not support unsigned max operations. If 'Expr' is constant during 297 // Scop execution we treat this as a parameter, otherwise we bail out. 298 for (int i = 0, e = Expr->getNumOperands(); i < e; ++i) { 299 ValidatorResult Op = visit(Expr->getOperand(i)); 300 301 if (!Op.isConstant()) { 302 DEBUG(dbgs() << "INVALID: UMaxExpr has a non-constant operand"); 303 return ValidatorResult(SCEVType::INVALID); 304 } 305 } 306 307 return ValidatorResult(SCEVType::PARAM, Expr); 308 } 309 310 ValidatorResult visitGenericInst(Instruction *I, const SCEV *S) { 311 if (R->contains(I)) { 312 DEBUG(dbgs() << "INVALID: UnknownExpr references an instruction " 313 "within the region\n"); 314 return ValidatorResult(SCEVType::INVALID); 315 } 316 317 return ValidatorResult(SCEVType::PARAM, S); 318 } 319 320 ValidatorResult visitCallInstruction(Instruction *I, const SCEV *S) { 321 assert(I->getOpcode() == Instruction::Call && "Call instruction expected"); 322 323 if (R->contains(I)) { 324 auto Call = cast<CallInst>(I); 325 326 if (!isConstCall(Call)) 327 return ValidatorResult(SCEVType::INVALID, S); 328 } 329 return ValidatorResult(SCEVType::PARAM, S); 330 } 331 332 ValidatorResult visitLoadInstruction(Instruction *I, const SCEV *S) { 333 if (R->contains(I) && ILS) { 334 ILS->insert(cast<LoadInst>(I)); 335 return ValidatorResult(SCEVType::PARAM, S); 336 } 337 338 return visitGenericInst(I, S); 339 } 340 341 ValidatorResult visitDivision(const SCEV *Dividend, const SCEV *Divisor, 342 const SCEV *DivExpr, 343 Instruction *SDiv = nullptr) { 344 345 // First check if we might be able to model the division, thus if the 346 // divisor is constant. If so, check the dividend, otherwise check if 347 // the whole division can be seen as a parameter. 348 if (isa<SCEVConstant>(Divisor) && !Divisor->isZero()) 349 return visit(Dividend); 350 351 // For signed divisions use the SDiv instruction to check for a parameter 352 // division, for unsigned divisions check the operands. 353 if (SDiv) 354 return visitGenericInst(SDiv, DivExpr); 355 356 ValidatorResult LHS = visit(Dividend); 357 ValidatorResult RHS = visit(Divisor); 358 if (LHS.isConstant() && RHS.isConstant()) 359 return ValidatorResult(SCEVType::PARAM, DivExpr); 360 361 DEBUG(dbgs() << "INVALID: unsigned division of non-constant expressions"); 362 return ValidatorResult(SCEVType::INVALID); 363 } 364 365 ValidatorResult visitUDivExpr(const SCEVUDivExpr *Expr) { 366 if (!PollyAllowUnsignedOperations) 367 return ValidatorResult(SCEVType::INVALID); 368 369 auto *Dividend = Expr->getLHS(); 370 auto *Divisor = Expr->getRHS(); 371 return visitDivision(Dividend, Divisor, Expr); 372 } 373 374 ValidatorResult visitSDivInstruction(Instruction *SDiv, const SCEV *Expr) { 375 assert(SDiv->getOpcode() == Instruction::SDiv && 376 "Assumed SDiv instruction!"); 377 378 auto *Dividend = SE.getSCEV(SDiv->getOperand(0)); 379 auto *Divisor = SE.getSCEV(SDiv->getOperand(1)); 380 return visitDivision(Dividend, Divisor, Expr, SDiv); 381 } 382 383 ValidatorResult visitSRemInstruction(Instruction *SRem, const SCEV *S) { 384 assert(SRem->getOpcode() == Instruction::SRem && 385 "Assumed SRem instruction!"); 386 387 auto *Divisor = SRem->getOperand(1); 388 auto *CI = dyn_cast<ConstantInt>(Divisor); 389 if (!CI || CI->isZeroValue()) 390 return visitGenericInst(SRem, S); 391 392 auto *Dividend = SRem->getOperand(0); 393 auto *DividendSCEV = SE.getSCEV(Dividend); 394 return visit(DividendSCEV); 395 } 396 397 ValidatorResult visitUnknown(const SCEVUnknown *Expr) { 398 Value *V = Expr->getValue(); 399 400 if (!Expr->getType()->isIntegerTy() && !Expr->getType()->isPointerTy()) { 401 DEBUG(dbgs() << "INVALID: UnknownExpr is not an integer or pointer"); 402 return ValidatorResult(SCEVType::INVALID); 403 } 404 405 if (isa<UndefValue>(V)) { 406 DEBUG(dbgs() << "INVALID: UnknownExpr references an undef value"); 407 return ValidatorResult(SCEVType::INVALID); 408 } 409 410 if (Instruction *I = dyn_cast<Instruction>(Expr->getValue())) { 411 switch (I->getOpcode()) { 412 case Instruction::IntToPtr: 413 return visit(SE.getSCEVAtScope(I->getOperand(0), Scope)); 414 case Instruction::PtrToInt: 415 return visit(SE.getSCEVAtScope(I->getOperand(0), Scope)); 416 case Instruction::Load: 417 return visitLoadInstruction(I, Expr); 418 case Instruction::SDiv: 419 return visitSDivInstruction(I, Expr); 420 case Instruction::SRem: 421 return visitSRemInstruction(I, Expr); 422 case Instruction::Call: 423 return visitCallInstruction(I, Expr); 424 default: 425 return visitGenericInst(I, Expr); 426 } 427 } 428 429 return ValidatorResult(SCEVType::PARAM, Expr); 430 } 431 }; 432 433 class SCEVHasIVParams { 434 bool HasIVParams = false; 435 436 public: 437 SCEVHasIVParams() {} 438 439 bool follow(const SCEV *S) { 440 const SCEVUnknown *Unknown = dyn_cast<SCEVUnknown>(S); 441 if (!Unknown) 442 return true; 443 444 CallInst *Call = dyn_cast<CallInst>(Unknown->getValue()); 445 446 if (!Call) 447 return true; 448 449 if (isConstCall(Call)) { 450 HasIVParams = true; 451 return false; 452 } 453 454 return true; 455 } 456 457 bool isDone() { return HasIVParams; } 458 bool hasIVParams() { return HasIVParams; } 459 }; 460 461 /// Check whether a SCEV refers to an SSA name defined inside a region. 462 class SCEVInRegionDependences { 463 const Region *R; 464 Loop *Scope; 465 bool AllowLoops; 466 bool HasInRegionDeps = false; 467 468 public: 469 SCEVInRegionDependences(const Region *R, Loop *Scope, bool AllowLoops) 470 : R(R), Scope(Scope), AllowLoops(AllowLoops) {} 471 472 bool follow(const SCEV *S) { 473 if (auto Unknown = dyn_cast<SCEVUnknown>(S)) { 474 Instruction *Inst = dyn_cast<Instruction>(Unknown->getValue()); 475 476 CallInst *Call = dyn_cast<CallInst>(Unknown->getValue()); 477 478 if (Call && isConstCall(Call)) 479 return false; 480 481 // Return true when Inst is defined inside the region R. 482 if (!Inst || !R->contains(Inst)) 483 return true; 484 485 HasInRegionDeps = true; 486 return false; 487 } 488 489 if (auto AddRec = dyn_cast<SCEVAddRecExpr>(S)) { 490 if (AllowLoops) 491 return true; 492 493 if (!Scope) { 494 HasInRegionDeps = true; 495 return false; 496 } 497 auto *L = AddRec->getLoop(); 498 if (R->contains(L) && !L->contains(Scope)) { 499 HasInRegionDeps = true; 500 return false; 501 } 502 } 503 504 return true; 505 } 506 bool isDone() { return false; } 507 bool hasDependences() { return HasInRegionDeps; } 508 }; 509 510 namespace polly { 511 /// Find all loops referenced in SCEVAddRecExprs. 512 class SCEVFindLoops { 513 SetVector<const Loop *> &Loops; 514 515 public: 516 SCEVFindLoops(SetVector<const Loop *> &Loops) : Loops(Loops) {} 517 518 bool follow(const SCEV *S) { 519 if (const SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(S)) 520 Loops.insert(AddRec->getLoop()); 521 return true; 522 } 523 bool isDone() { return false; } 524 }; 525 526 void findLoops(const SCEV *Expr, SetVector<const Loop *> &Loops) { 527 SCEVFindLoops FindLoops(Loops); 528 SCEVTraversal<SCEVFindLoops> ST(FindLoops); 529 ST.visitAll(Expr); 530 } 531 532 /// Find all values referenced in SCEVUnknowns. 533 class SCEVFindValues { 534 ScalarEvolution &SE; 535 SetVector<Value *> &Values; 536 537 public: 538 SCEVFindValues(ScalarEvolution &SE, SetVector<Value *> &Values) 539 : SE(SE), Values(Values) {} 540 541 bool follow(const SCEV *S) { 542 const SCEVUnknown *Unknown = dyn_cast<SCEVUnknown>(S); 543 if (!Unknown) 544 return true; 545 546 Values.insert(Unknown->getValue()); 547 Instruction *Inst = dyn_cast<Instruction>(Unknown->getValue()); 548 if (!Inst || (Inst->getOpcode() != Instruction::SRem && 549 Inst->getOpcode() != Instruction::SDiv)) 550 return false; 551 552 auto *Dividend = SE.getSCEV(Inst->getOperand(1)); 553 if (!isa<SCEVConstant>(Dividend)) 554 return false; 555 556 auto *Divisor = SE.getSCEV(Inst->getOperand(0)); 557 SCEVFindValues FindValues(SE, Values); 558 SCEVTraversal<SCEVFindValues> ST(FindValues); 559 ST.visitAll(Dividend); 560 ST.visitAll(Divisor); 561 562 return false; 563 } 564 bool isDone() { return false; } 565 }; 566 567 void findValues(const SCEV *Expr, ScalarEvolution &SE, 568 SetVector<Value *> &Values) { 569 SCEVFindValues FindValues(SE, Values); 570 SCEVTraversal<SCEVFindValues> ST(FindValues); 571 ST.visitAll(Expr); 572 } 573 574 bool hasIVParams(const SCEV *Expr) { 575 SCEVHasIVParams HasIVParams; 576 SCEVTraversal<SCEVHasIVParams> ST(HasIVParams); 577 ST.visitAll(Expr); 578 return HasIVParams.hasIVParams(); 579 } 580 581 bool hasScalarDepsInsideRegion(const SCEV *Expr, const Region *R, 582 llvm::Loop *Scope, bool AllowLoops) { 583 SCEVInRegionDependences InRegionDeps(R, Scope, AllowLoops); 584 SCEVTraversal<SCEVInRegionDependences> ST(InRegionDeps); 585 ST.visitAll(Expr); 586 return InRegionDeps.hasDependences(); 587 } 588 589 bool isAffineExpr(const Region *R, llvm::Loop *Scope, const SCEV *Expr, 590 ScalarEvolution &SE, InvariantLoadsSetTy *ILS) { 591 if (isa<SCEVCouldNotCompute>(Expr)) 592 return false; 593 594 SCEVValidator Validator(R, Scope, SE, ILS); 595 DEBUG({ 596 dbgs() << "\n"; 597 dbgs() << "Expr: " << *Expr << "\n"; 598 dbgs() << "Region: " << R->getNameStr() << "\n"; 599 dbgs() << " -> "; 600 }); 601 602 ValidatorResult Result = Validator.visit(Expr); 603 604 DEBUG({ 605 if (Result.isValid()) 606 dbgs() << "VALID\n"; 607 dbgs() << "\n"; 608 }); 609 610 return Result.isValid(); 611 } 612 613 static bool isAffineExpr(Value *V, const Region *R, Loop *Scope, 614 ScalarEvolution &SE, ParameterSetTy &Params) { 615 auto *E = SE.getSCEV(V); 616 if (isa<SCEVCouldNotCompute>(E)) 617 return false; 618 619 SCEVValidator Validator(R, Scope, SE, nullptr); 620 ValidatorResult Result = Validator.visit(E); 621 if (!Result.isValid()) 622 return false; 623 624 auto ResultParams = Result.getParameters(); 625 Params.insert(ResultParams.begin(), ResultParams.end()); 626 627 return true; 628 } 629 630 bool isAffineConstraint(Value *V, const Region *R, llvm::Loop *Scope, 631 ScalarEvolution &SE, ParameterSetTy &Params, 632 bool OrExpr) { 633 if (auto *ICmp = dyn_cast<ICmpInst>(V)) { 634 return isAffineConstraint(ICmp->getOperand(0), R, Scope, SE, Params, 635 true) && 636 isAffineConstraint(ICmp->getOperand(1), R, Scope, SE, Params, true); 637 } else if (auto *BinOp = dyn_cast<BinaryOperator>(V)) { 638 auto Opcode = BinOp->getOpcode(); 639 if (Opcode == Instruction::And || Opcode == Instruction::Or) 640 return isAffineConstraint(BinOp->getOperand(0), R, Scope, SE, Params, 641 false) && 642 isAffineConstraint(BinOp->getOperand(1), R, Scope, SE, Params, 643 false); 644 /* Fall through */ 645 } 646 647 if (!OrExpr) 648 return false; 649 650 return isAffineExpr(V, R, Scope, SE, Params); 651 } 652 653 ParameterSetTy getParamsInAffineExpr(const Region *R, Loop *Scope, 654 const SCEV *Expr, ScalarEvolution &SE) { 655 if (isa<SCEVCouldNotCompute>(Expr)) 656 return ParameterSetTy(); 657 658 InvariantLoadsSetTy ILS; 659 SCEVValidator Validator(R, Scope, SE, &ILS); 660 ValidatorResult Result = Validator.visit(Expr); 661 assert(Result.isValid() && "Requested parameters for an invalid SCEV!"); 662 663 return Result.getParameters(); 664 } 665 666 std::pair<const SCEVConstant *, const SCEV *> 667 extractConstantFactor(const SCEV *S, ScalarEvolution &SE) { 668 auto *ConstPart = cast<SCEVConstant>(SE.getConstant(S->getType(), 1)); 669 670 if (auto *Constant = dyn_cast<SCEVConstant>(S)) 671 return std::make_pair(Constant, SE.getConstant(S->getType(), 1)); 672 673 auto *AddRec = dyn_cast<SCEVAddRecExpr>(S); 674 if (AddRec) { 675 auto *StartExpr = AddRec->getStart(); 676 if (StartExpr->isZero()) { 677 auto StepPair = extractConstantFactor(AddRec->getStepRecurrence(SE), SE); 678 auto *LeftOverAddRec = 679 SE.getAddRecExpr(StartExpr, StepPair.second, AddRec->getLoop(), 680 AddRec->getNoWrapFlags()); 681 return std::make_pair(StepPair.first, LeftOverAddRec); 682 } 683 return std::make_pair(ConstPart, S); 684 } 685 686 if (auto *Add = dyn_cast<SCEVAddExpr>(S)) { 687 SmallVector<const SCEV *, 4> LeftOvers; 688 auto Op0Pair = extractConstantFactor(Add->getOperand(0), SE); 689 auto *Factor = Op0Pair.first; 690 if (SE.isKnownNegative(Factor)) { 691 Factor = cast<SCEVConstant>(SE.getNegativeSCEV(Factor)); 692 LeftOvers.push_back(SE.getNegativeSCEV(Op0Pair.second)); 693 } else { 694 LeftOvers.push_back(Op0Pair.second); 695 } 696 697 for (unsigned u = 1, e = Add->getNumOperands(); u < e; u++) { 698 auto OpUPair = extractConstantFactor(Add->getOperand(u), SE); 699 // TODO: Use something smarter than equality here, e.g., gcd. 700 if (Factor == OpUPair.first) 701 LeftOvers.push_back(OpUPair.second); 702 else if (Factor == SE.getNegativeSCEV(OpUPair.first)) 703 LeftOvers.push_back(SE.getNegativeSCEV(OpUPair.second)); 704 else 705 return std::make_pair(ConstPart, S); 706 } 707 708 auto *NewAdd = SE.getAddExpr(LeftOvers, Add->getNoWrapFlags()); 709 return std::make_pair(Factor, NewAdd); 710 } 711 712 auto *Mul = dyn_cast<SCEVMulExpr>(S); 713 if (!Mul) 714 return std::make_pair(ConstPart, S); 715 716 SmallVector<const SCEV *, 4> LeftOvers; 717 for (auto *Op : Mul->operands()) 718 if (isa<SCEVConstant>(Op)) 719 ConstPart = cast<SCEVConstant>(SE.getMulExpr(ConstPart, Op)); 720 else 721 LeftOvers.push_back(Op); 722 723 return std::make_pair(ConstPart, SE.getMulExpr(LeftOvers)); 724 } 725 } // namespace polly 726