1 2 #include "polly/Support/SCEVValidator.h" 3 #include "polly/ScopDetection.h" 4 #include "llvm/Analysis/RegionInfo.h" 5 #include "llvm/Analysis/ScalarEvolution.h" 6 #include "llvm/Analysis/ScalarEvolutionExpressions.h" 7 #include "llvm/Support/Debug.h" 8 9 using namespace llvm; 10 using namespace polly; 11 12 #define DEBUG_TYPE "polly-scev-validator" 13 14 namespace SCEVType { 15 /// The type of a SCEV 16 /// 17 /// To check for the validity of a SCEV we assign to each SCEV a type. The 18 /// possible types are INT, PARAM, IV and INVALID. The order of the types is 19 /// important. The subexpressions of SCEV with a type X can only have a type 20 /// that is smaller or equal than X. 21 enum TYPE { 22 // An integer value. 23 INT, 24 25 // An expression that is constant during the execution of the Scop, 26 // but that may depend on parameters unknown at compile time. 27 PARAM, 28 29 // An expression that may change during the execution of the SCoP. 30 IV, 31 32 // An invalid expression. 33 INVALID 34 }; 35 } // namespace SCEVType 36 37 /// The result the validator returns for a SCEV expression. 38 class ValidatorResult { 39 /// The type of the expression 40 SCEVType::TYPE Type; 41 42 /// The set of Parameters in the expression. 43 ParameterSetTy Parameters; 44 45 public: 46 /// The copy constructor 47 ValidatorResult(const ValidatorResult &Source) { 48 Type = Source.Type; 49 Parameters = Source.Parameters; 50 } 51 52 /// Construct a result with a certain type and no parameters. 53 ValidatorResult(SCEVType::TYPE Type) : Type(Type) { 54 assert(Type != SCEVType::PARAM && "Did you forget to pass the parameter"); 55 } 56 57 /// Construct a result with a certain type and a single parameter. 58 ValidatorResult(SCEVType::TYPE Type, const SCEV *Expr) : Type(Type) { 59 Parameters.insert(Expr); 60 } 61 62 /// Get the type of the ValidatorResult. 63 SCEVType::TYPE getType() { return Type; } 64 65 /// Is the analyzed SCEV constant during the execution of the SCoP. 66 bool isConstant() { return Type == SCEVType::INT || Type == SCEVType::PARAM; } 67 68 /// Is the analyzed SCEV valid. 69 bool isValid() { return Type != SCEVType::INVALID; } 70 71 /// Is the analyzed SCEV of Type IV. 72 bool isIV() { return Type == SCEVType::IV; } 73 74 /// Is the analyzed SCEV of Type INT. 75 bool isINT() { return Type == SCEVType::INT; } 76 77 /// Is the analyzed SCEV of Type PARAM. 78 bool isPARAM() { return Type == SCEVType::PARAM; } 79 80 /// Get the parameters of this validator result. 81 const ParameterSetTy &getParameters() { return Parameters; } 82 83 /// Add the parameters of Source to this result. 84 void addParamsFrom(const ValidatorResult &Source) { 85 Parameters.insert(Source.Parameters.begin(), Source.Parameters.end()); 86 } 87 88 /// Merge a result. 89 /// 90 /// This means to merge the parameters and to set the Type to the most 91 /// specific Type that matches both. 92 void merge(const ValidatorResult &ToMerge) { 93 Type = std::max(Type, ToMerge.Type); 94 addParamsFrom(ToMerge); 95 } 96 97 void print(raw_ostream &OS) { 98 switch (Type) { 99 case SCEVType::INT: 100 OS << "SCEVType::INT"; 101 break; 102 case SCEVType::PARAM: 103 OS << "SCEVType::PARAM"; 104 break; 105 case SCEVType::IV: 106 OS << "SCEVType::IV"; 107 break; 108 case SCEVType::INVALID: 109 OS << "SCEVType::INVALID"; 110 break; 111 } 112 } 113 }; 114 115 raw_ostream &operator<<(raw_ostream &OS, class ValidatorResult &VR) { 116 VR.print(OS); 117 return OS; 118 } 119 120 bool polly::isConstCall(llvm::CallInst *Call) { 121 if (Call->mayReadOrWriteMemory()) 122 return false; 123 124 for (auto &Operand : Call->arg_operands()) 125 if (!isa<ConstantInt>(&Operand)) 126 return false; 127 128 return true; 129 } 130 131 /// Check if a SCEV is valid in a SCoP. 132 struct SCEVValidator 133 : public SCEVVisitor<SCEVValidator, class ValidatorResult> { 134 private: 135 const Region *R; 136 Loop *Scope; 137 ScalarEvolution &SE; 138 InvariantLoadsSetTy *ILS; 139 140 public: 141 SCEVValidator(const Region *R, Loop *Scope, ScalarEvolution &SE, 142 InvariantLoadsSetTy *ILS) 143 : R(R), Scope(Scope), SE(SE), ILS(ILS) {} 144 145 class ValidatorResult visitConstant(const SCEVConstant *Constant) { 146 return ValidatorResult(SCEVType::INT); 147 } 148 149 class ValidatorResult visitZeroExtendOrTruncateExpr(const SCEV *Expr, 150 const SCEV *Operand) { 151 ValidatorResult Op = visit(Operand); 152 auto Type = Op.getType(); 153 154 // If unsigned operations are allowed return the operand, otherwise 155 // check if we can model the expression without unsigned assumptions. 156 if (PollyAllowUnsignedOperations || Type == SCEVType::INVALID) 157 return Op; 158 159 if (Type == SCEVType::IV) 160 return ValidatorResult(SCEVType::INVALID); 161 return ValidatorResult(SCEVType::PARAM, Expr); 162 } 163 164 class ValidatorResult visitTruncateExpr(const SCEVTruncateExpr *Expr) { 165 return visitZeroExtendOrTruncateExpr(Expr, Expr->getOperand()); 166 } 167 168 class ValidatorResult visitZeroExtendExpr(const SCEVZeroExtendExpr *Expr) { 169 return visitZeroExtendOrTruncateExpr(Expr, Expr->getOperand()); 170 } 171 172 class ValidatorResult visitSignExtendExpr(const SCEVSignExtendExpr *Expr) { 173 return visit(Expr->getOperand()); 174 } 175 176 class ValidatorResult visitAddExpr(const SCEVAddExpr *Expr) { 177 ValidatorResult Return(SCEVType::INT); 178 179 for (int i = 0, e = Expr->getNumOperands(); i < e; ++i) { 180 ValidatorResult Op = visit(Expr->getOperand(i)); 181 Return.merge(Op); 182 183 // Early exit. 184 if (!Return.isValid()) 185 break; 186 } 187 188 return Return; 189 } 190 191 class ValidatorResult visitMulExpr(const SCEVMulExpr *Expr) { 192 ValidatorResult Return(SCEVType::INT); 193 194 bool HasMultipleParams = false; 195 196 for (int i = 0, e = Expr->getNumOperands(); i < e; ++i) { 197 ValidatorResult Op = visit(Expr->getOperand(i)); 198 199 if (Op.isINT()) 200 continue; 201 202 if (Op.isPARAM() && Return.isPARAM()) { 203 HasMultipleParams = true; 204 continue; 205 } 206 207 if ((Op.isIV() || Op.isPARAM()) && !Return.isINT()) { 208 LLVM_DEBUG( 209 dbgs() << "INVALID: More than one non-int operand in MulExpr\n" 210 << "\tExpr: " << *Expr << "\n" 211 << "\tPrevious expression type: " << Return << "\n" 212 << "\tNext operand (" << Op << "): " << *Expr->getOperand(i) 213 << "\n"); 214 215 return ValidatorResult(SCEVType::INVALID); 216 } 217 218 Return.merge(Op); 219 } 220 221 if (HasMultipleParams && Return.isValid()) 222 return ValidatorResult(SCEVType::PARAM, Expr); 223 224 return Return; 225 } 226 227 class ValidatorResult visitAddRecExpr(const SCEVAddRecExpr *Expr) { 228 if (!Expr->isAffine()) { 229 LLVM_DEBUG(dbgs() << "INVALID: AddRec is not affine"); 230 return ValidatorResult(SCEVType::INVALID); 231 } 232 233 ValidatorResult Start = visit(Expr->getStart()); 234 ValidatorResult Recurrence = visit(Expr->getStepRecurrence(SE)); 235 236 if (!Start.isValid()) 237 return Start; 238 239 if (!Recurrence.isValid()) 240 return Recurrence; 241 242 auto *L = Expr->getLoop(); 243 if (R->contains(L) && (!Scope || !L->contains(Scope))) { 244 LLVM_DEBUG( 245 dbgs() << "INVALID: Loop of AddRec expression boxed in an a " 246 "non-affine subregion or has a non-synthesizable exit " 247 "value."); 248 return ValidatorResult(SCEVType::INVALID); 249 } 250 251 if (R->contains(L)) { 252 if (Recurrence.isINT()) { 253 ValidatorResult Result(SCEVType::IV); 254 Result.addParamsFrom(Start); 255 return Result; 256 } 257 258 LLVM_DEBUG(dbgs() << "INVALID: AddRec within scop has non-int" 259 "recurrence part"); 260 return ValidatorResult(SCEVType::INVALID); 261 } 262 263 assert(Recurrence.isConstant() && "Expected 'Recurrence' to be constant"); 264 265 // Directly generate ValidatorResult for Expr if 'start' is zero. 266 if (Expr->getStart()->isZero()) 267 return ValidatorResult(SCEVType::PARAM, Expr); 268 269 // Translate AddRecExpr from '{start, +, inc}' into 'start + {0, +, inc}' 270 // if 'start' is not zero. 271 const SCEV *ZeroStartExpr = SE.getAddRecExpr( 272 SE.getConstant(Expr->getStart()->getType(), 0), 273 Expr->getStepRecurrence(SE), Expr->getLoop(), Expr->getNoWrapFlags()); 274 275 ValidatorResult ZeroStartResult = 276 ValidatorResult(SCEVType::PARAM, ZeroStartExpr); 277 ZeroStartResult.addParamsFrom(Start); 278 279 return ZeroStartResult; 280 } 281 282 class ValidatorResult visitSMaxExpr(const SCEVSMaxExpr *Expr) { 283 ValidatorResult Return(SCEVType::INT); 284 285 for (int i = 0, e = Expr->getNumOperands(); i < e; ++i) { 286 ValidatorResult Op = visit(Expr->getOperand(i)); 287 288 if (!Op.isValid()) 289 return Op; 290 291 Return.merge(Op); 292 } 293 294 return Return; 295 } 296 297 class ValidatorResult visitSMinExpr(const SCEVSMinExpr *Expr) { 298 ValidatorResult Return(SCEVType::INT); 299 300 for (int i = 0, e = Expr->getNumOperands(); i < e; ++i) { 301 ValidatorResult Op = visit(Expr->getOperand(i)); 302 303 if (!Op.isValid()) 304 return Op; 305 306 Return.merge(Op); 307 } 308 309 return Return; 310 } 311 312 class ValidatorResult visitUMaxExpr(const SCEVUMaxExpr *Expr) { 313 // We do not support unsigned max operations. If 'Expr' is constant during 314 // Scop execution we treat this as a parameter, otherwise we bail out. 315 for (int i = 0, e = Expr->getNumOperands(); i < e; ++i) { 316 ValidatorResult Op = visit(Expr->getOperand(i)); 317 318 if (!Op.isConstant()) { 319 LLVM_DEBUG(dbgs() << "INVALID: UMaxExpr has a non-constant operand"); 320 return ValidatorResult(SCEVType::INVALID); 321 } 322 } 323 324 return ValidatorResult(SCEVType::PARAM, Expr); 325 } 326 327 class ValidatorResult visitUMinExpr(const SCEVUMinExpr *Expr) { 328 // We do not support unsigned min operations. If 'Expr' is constant during 329 // Scop execution we treat this as a parameter, otherwise we bail out. 330 for (int i = 0, e = Expr->getNumOperands(); i < e; ++i) { 331 ValidatorResult Op = visit(Expr->getOperand(i)); 332 333 if (!Op.isConstant()) { 334 LLVM_DEBUG(dbgs() << "INVALID: UMinExpr has a non-constant operand"); 335 return ValidatorResult(SCEVType::INVALID); 336 } 337 } 338 339 return ValidatorResult(SCEVType::PARAM, Expr); 340 } 341 342 ValidatorResult visitGenericInst(Instruction *I, const SCEV *S) { 343 if (R->contains(I)) { 344 LLVM_DEBUG(dbgs() << "INVALID: UnknownExpr references an instruction " 345 "within the region\n"); 346 return ValidatorResult(SCEVType::INVALID); 347 } 348 349 return ValidatorResult(SCEVType::PARAM, S); 350 } 351 352 ValidatorResult visitCallInstruction(Instruction *I, const SCEV *S) { 353 assert(I->getOpcode() == Instruction::Call && "Call instruction expected"); 354 355 if (R->contains(I)) { 356 auto Call = cast<CallInst>(I); 357 358 if (!isConstCall(Call)) 359 return ValidatorResult(SCEVType::INVALID, S); 360 } 361 return ValidatorResult(SCEVType::PARAM, S); 362 } 363 364 ValidatorResult visitLoadInstruction(Instruction *I, const SCEV *S) { 365 if (R->contains(I) && ILS) { 366 ILS->insert(cast<LoadInst>(I)); 367 return ValidatorResult(SCEVType::PARAM, S); 368 } 369 370 return visitGenericInst(I, S); 371 } 372 373 ValidatorResult visitDivision(const SCEV *Dividend, const SCEV *Divisor, 374 const SCEV *DivExpr, 375 Instruction *SDiv = nullptr) { 376 377 // First check if we might be able to model the division, thus if the 378 // divisor is constant. If so, check the dividend, otherwise check if 379 // the whole division can be seen as a parameter. 380 if (isa<SCEVConstant>(Divisor) && !Divisor->isZero()) 381 return visit(Dividend); 382 383 // For signed divisions use the SDiv instruction to check for a parameter 384 // division, for unsigned divisions check the operands. 385 if (SDiv) 386 return visitGenericInst(SDiv, DivExpr); 387 388 ValidatorResult LHS = visit(Dividend); 389 ValidatorResult RHS = visit(Divisor); 390 if (LHS.isConstant() && RHS.isConstant()) 391 return ValidatorResult(SCEVType::PARAM, DivExpr); 392 393 LLVM_DEBUG( 394 dbgs() << "INVALID: unsigned division of non-constant expressions"); 395 return ValidatorResult(SCEVType::INVALID); 396 } 397 398 ValidatorResult visitUDivExpr(const SCEVUDivExpr *Expr) { 399 if (!PollyAllowUnsignedOperations) 400 return ValidatorResult(SCEVType::INVALID); 401 402 auto *Dividend = Expr->getLHS(); 403 auto *Divisor = Expr->getRHS(); 404 return visitDivision(Dividend, Divisor, Expr); 405 } 406 407 ValidatorResult visitSDivInstruction(Instruction *SDiv, const SCEV *Expr) { 408 assert(SDiv->getOpcode() == Instruction::SDiv && 409 "Assumed SDiv instruction!"); 410 411 auto *Dividend = SE.getSCEV(SDiv->getOperand(0)); 412 auto *Divisor = SE.getSCEV(SDiv->getOperand(1)); 413 return visitDivision(Dividend, Divisor, Expr, SDiv); 414 } 415 416 ValidatorResult visitSRemInstruction(Instruction *SRem, const SCEV *S) { 417 assert(SRem->getOpcode() == Instruction::SRem && 418 "Assumed SRem instruction!"); 419 420 auto *Divisor = SRem->getOperand(1); 421 auto *CI = dyn_cast<ConstantInt>(Divisor); 422 if (!CI || CI->isZeroValue()) 423 return visitGenericInst(SRem, S); 424 425 auto *Dividend = SRem->getOperand(0); 426 auto *DividendSCEV = SE.getSCEV(Dividend); 427 return visit(DividendSCEV); 428 } 429 430 ValidatorResult visitUnknown(const SCEVUnknown *Expr) { 431 Value *V = Expr->getValue(); 432 433 if (!Expr->getType()->isIntegerTy() && !Expr->getType()->isPointerTy()) { 434 LLVM_DEBUG(dbgs() << "INVALID: UnknownExpr is not an integer or pointer"); 435 return ValidatorResult(SCEVType::INVALID); 436 } 437 438 if (isa<UndefValue>(V)) { 439 LLVM_DEBUG(dbgs() << "INVALID: UnknownExpr references an undef value"); 440 return ValidatorResult(SCEVType::INVALID); 441 } 442 443 if (Instruction *I = dyn_cast<Instruction>(Expr->getValue())) { 444 switch (I->getOpcode()) { 445 case Instruction::IntToPtr: 446 return visit(SE.getSCEVAtScope(I->getOperand(0), Scope)); 447 case Instruction::PtrToInt: 448 return visit(SE.getSCEVAtScope(I->getOperand(0), Scope)); 449 case Instruction::Load: 450 return visitLoadInstruction(I, Expr); 451 case Instruction::SDiv: 452 return visitSDivInstruction(I, Expr); 453 case Instruction::SRem: 454 return visitSRemInstruction(I, Expr); 455 case Instruction::Call: 456 return visitCallInstruction(I, Expr); 457 default: 458 return visitGenericInst(I, Expr); 459 } 460 } 461 462 return ValidatorResult(SCEVType::PARAM, Expr); 463 } 464 }; 465 466 class SCEVHasIVParams { 467 bool HasIVParams = false; 468 469 public: 470 SCEVHasIVParams() {} 471 472 bool follow(const SCEV *S) { 473 const SCEVUnknown *Unknown = dyn_cast<SCEVUnknown>(S); 474 if (!Unknown) 475 return true; 476 477 CallInst *Call = dyn_cast<CallInst>(Unknown->getValue()); 478 479 if (!Call) 480 return true; 481 482 if (isConstCall(Call)) { 483 HasIVParams = true; 484 return false; 485 } 486 487 return true; 488 } 489 490 bool isDone() { return HasIVParams; } 491 bool hasIVParams() { return HasIVParams; } 492 }; 493 494 /// Check whether a SCEV refers to an SSA name defined inside a region. 495 class SCEVInRegionDependences { 496 const Region *R; 497 Loop *Scope; 498 const InvariantLoadsSetTy &ILS; 499 bool AllowLoops; 500 bool HasInRegionDeps = false; 501 502 public: 503 SCEVInRegionDependences(const Region *R, Loop *Scope, bool AllowLoops, 504 const InvariantLoadsSetTy &ILS) 505 : R(R), Scope(Scope), ILS(ILS), AllowLoops(AllowLoops) {} 506 507 bool follow(const SCEV *S) { 508 if (auto Unknown = dyn_cast<SCEVUnknown>(S)) { 509 Instruction *Inst = dyn_cast<Instruction>(Unknown->getValue()); 510 511 CallInst *Call = dyn_cast<CallInst>(Unknown->getValue()); 512 513 if (Call && isConstCall(Call)) 514 return false; 515 516 if (Inst) { 517 // When we invariant load hoist a load, we first make sure that there 518 // can be no dependences created by it in the Scop region. So, we should 519 // not consider scalar dependences to `LoadInst`s that are invariant 520 // load hoisted. 521 // 522 // If this check is not present, then we create data dependences which 523 // are strictly not necessary by tracking the invariant load as a 524 // scalar. 525 LoadInst *LI = dyn_cast<LoadInst>(Inst); 526 if (LI && ILS.count(LI) > 0) 527 return false; 528 } 529 530 // Return true when Inst is defined inside the region R. 531 if (!Inst || !R->contains(Inst)) 532 return true; 533 534 HasInRegionDeps = true; 535 return false; 536 } 537 538 if (auto AddRec = dyn_cast<SCEVAddRecExpr>(S)) { 539 if (AllowLoops) 540 return true; 541 542 auto *L = AddRec->getLoop(); 543 if (R->contains(L) && !L->contains(Scope)) { 544 HasInRegionDeps = true; 545 return false; 546 } 547 } 548 549 return true; 550 } 551 bool isDone() { return false; } 552 bool hasDependences() { return HasInRegionDeps; } 553 }; 554 555 namespace polly { 556 /// Find all loops referenced in SCEVAddRecExprs. 557 class SCEVFindLoops { 558 SetVector<const Loop *> &Loops; 559 560 public: 561 SCEVFindLoops(SetVector<const Loop *> &Loops) : Loops(Loops) {} 562 563 bool follow(const SCEV *S) { 564 if (const SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(S)) 565 Loops.insert(AddRec->getLoop()); 566 return true; 567 } 568 bool isDone() { return false; } 569 }; 570 571 void findLoops(const SCEV *Expr, SetVector<const Loop *> &Loops) { 572 SCEVFindLoops FindLoops(Loops); 573 SCEVTraversal<SCEVFindLoops> ST(FindLoops); 574 ST.visitAll(Expr); 575 } 576 577 /// Find all values referenced in SCEVUnknowns. 578 class SCEVFindValues { 579 ScalarEvolution &SE; 580 SetVector<Value *> &Values; 581 582 public: 583 SCEVFindValues(ScalarEvolution &SE, SetVector<Value *> &Values) 584 : SE(SE), Values(Values) {} 585 586 bool follow(const SCEV *S) { 587 const SCEVUnknown *Unknown = dyn_cast<SCEVUnknown>(S); 588 if (!Unknown) 589 return true; 590 591 Values.insert(Unknown->getValue()); 592 Instruction *Inst = dyn_cast<Instruction>(Unknown->getValue()); 593 if (!Inst || (Inst->getOpcode() != Instruction::SRem && 594 Inst->getOpcode() != Instruction::SDiv)) 595 return false; 596 597 auto *Dividend = SE.getSCEV(Inst->getOperand(1)); 598 if (!isa<SCEVConstant>(Dividend)) 599 return false; 600 601 auto *Divisor = SE.getSCEV(Inst->getOperand(0)); 602 SCEVFindValues FindValues(SE, Values); 603 SCEVTraversal<SCEVFindValues> ST(FindValues); 604 ST.visitAll(Dividend); 605 ST.visitAll(Divisor); 606 607 return false; 608 } 609 bool isDone() { return false; } 610 }; 611 612 void findValues(const SCEV *Expr, ScalarEvolution &SE, 613 SetVector<Value *> &Values) { 614 SCEVFindValues FindValues(SE, Values); 615 SCEVTraversal<SCEVFindValues> ST(FindValues); 616 ST.visitAll(Expr); 617 } 618 619 bool hasIVParams(const SCEV *Expr) { 620 SCEVHasIVParams HasIVParams; 621 SCEVTraversal<SCEVHasIVParams> ST(HasIVParams); 622 ST.visitAll(Expr); 623 return HasIVParams.hasIVParams(); 624 } 625 626 bool hasScalarDepsInsideRegion(const SCEV *Expr, const Region *R, 627 llvm::Loop *Scope, bool AllowLoops, 628 const InvariantLoadsSetTy &ILS) { 629 SCEVInRegionDependences InRegionDeps(R, Scope, AllowLoops, ILS); 630 SCEVTraversal<SCEVInRegionDependences> ST(InRegionDeps); 631 ST.visitAll(Expr); 632 return InRegionDeps.hasDependences(); 633 } 634 635 bool isAffineExpr(const Region *R, llvm::Loop *Scope, const SCEV *Expr, 636 ScalarEvolution &SE, InvariantLoadsSetTy *ILS) { 637 if (isa<SCEVCouldNotCompute>(Expr)) 638 return false; 639 640 SCEVValidator Validator(R, Scope, SE, ILS); 641 LLVM_DEBUG({ 642 dbgs() << "\n"; 643 dbgs() << "Expr: " << *Expr << "\n"; 644 dbgs() << "Region: " << R->getNameStr() << "\n"; 645 dbgs() << " -> "; 646 }); 647 648 ValidatorResult Result = Validator.visit(Expr); 649 650 LLVM_DEBUG({ 651 if (Result.isValid()) 652 dbgs() << "VALID\n"; 653 dbgs() << "\n"; 654 }); 655 656 return Result.isValid(); 657 } 658 659 static bool isAffineExpr(Value *V, const Region *R, Loop *Scope, 660 ScalarEvolution &SE, ParameterSetTy &Params) { 661 auto *E = SE.getSCEV(V); 662 if (isa<SCEVCouldNotCompute>(E)) 663 return false; 664 665 SCEVValidator Validator(R, Scope, SE, nullptr); 666 ValidatorResult Result = Validator.visit(E); 667 if (!Result.isValid()) 668 return false; 669 670 auto ResultParams = Result.getParameters(); 671 Params.insert(ResultParams.begin(), ResultParams.end()); 672 673 return true; 674 } 675 676 bool isAffineConstraint(Value *V, const Region *R, llvm::Loop *Scope, 677 ScalarEvolution &SE, ParameterSetTy &Params, 678 bool OrExpr) { 679 if (auto *ICmp = dyn_cast<ICmpInst>(V)) { 680 return isAffineConstraint(ICmp->getOperand(0), R, Scope, SE, Params, 681 true) && 682 isAffineConstraint(ICmp->getOperand(1), R, Scope, SE, Params, true); 683 } else if (auto *BinOp = dyn_cast<BinaryOperator>(V)) { 684 auto Opcode = BinOp->getOpcode(); 685 if (Opcode == Instruction::And || Opcode == Instruction::Or) 686 return isAffineConstraint(BinOp->getOperand(0), R, Scope, SE, Params, 687 false) && 688 isAffineConstraint(BinOp->getOperand(1), R, Scope, SE, Params, 689 false); 690 /* Fall through */ 691 } 692 693 if (!OrExpr) 694 return false; 695 696 return isAffineExpr(V, R, Scope, SE, Params); 697 } 698 699 ParameterSetTy getParamsInAffineExpr(const Region *R, Loop *Scope, 700 const SCEV *Expr, ScalarEvolution &SE) { 701 if (isa<SCEVCouldNotCompute>(Expr)) 702 return ParameterSetTy(); 703 704 InvariantLoadsSetTy ILS; 705 SCEVValidator Validator(R, Scope, SE, &ILS); 706 ValidatorResult Result = Validator.visit(Expr); 707 assert(Result.isValid() && "Requested parameters for an invalid SCEV!"); 708 709 return Result.getParameters(); 710 } 711 712 std::pair<const SCEVConstant *, const SCEV *> 713 extractConstantFactor(const SCEV *S, ScalarEvolution &SE) { 714 auto *ConstPart = cast<SCEVConstant>(SE.getConstant(S->getType(), 1)); 715 716 if (auto *Constant = dyn_cast<SCEVConstant>(S)) 717 return std::make_pair(Constant, SE.getConstant(S->getType(), 1)); 718 719 auto *AddRec = dyn_cast<SCEVAddRecExpr>(S); 720 if (AddRec) { 721 auto *StartExpr = AddRec->getStart(); 722 if (StartExpr->isZero()) { 723 auto StepPair = extractConstantFactor(AddRec->getStepRecurrence(SE), SE); 724 auto *LeftOverAddRec = 725 SE.getAddRecExpr(StartExpr, StepPair.second, AddRec->getLoop(), 726 AddRec->getNoWrapFlags()); 727 return std::make_pair(StepPair.first, LeftOverAddRec); 728 } 729 return std::make_pair(ConstPart, S); 730 } 731 732 if (auto *Add = dyn_cast<SCEVAddExpr>(S)) { 733 SmallVector<const SCEV *, 4> LeftOvers; 734 auto Op0Pair = extractConstantFactor(Add->getOperand(0), SE); 735 auto *Factor = Op0Pair.first; 736 if (SE.isKnownNegative(Factor)) { 737 Factor = cast<SCEVConstant>(SE.getNegativeSCEV(Factor)); 738 LeftOvers.push_back(SE.getNegativeSCEV(Op0Pair.second)); 739 } else { 740 LeftOvers.push_back(Op0Pair.second); 741 } 742 743 for (unsigned u = 1, e = Add->getNumOperands(); u < e; u++) { 744 auto OpUPair = extractConstantFactor(Add->getOperand(u), SE); 745 // TODO: Use something smarter than equality here, e.g., gcd. 746 if (Factor == OpUPair.first) 747 LeftOvers.push_back(OpUPair.second); 748 else if (Factor == SE.getNegativeSCEV(OpUPair.first)) 749 LeftOvers.push_back(SE.getNegativeSCEV(OpUPair.second)); 750 else 751 return std::make_pair(ConstPart, S); 752 } 753 754 auto *NewAdd = SE.getAddExpr(LeftOvers, Add->getNoWrapFlags()); 755 return std::make_pair(Factor, NewAdd); 756 } 757 758 auto *Mul = dyn_cast<SCEVMulExpr>(S); 759 if (!Mul) 760 return std::make_pair(ConstPart, S); 761 762 SmallVector<const SCEV *, 4> LeftOvers; 763 for (auto *Op : Mul->operands()) 764 if (isa<SCEVConstant>(Op)) 765 ConstPart = cast<SCEVConstant>(SE.getMulExpr(ConstPart, Op)); 766 else 767 LeftOvers.push_back(Op); 768 769 return std::make_pair(ConstPart, SE.getMulExpr(LeftOvers)); 770 } 771 772 const SCEV *tryForwardThroughPHI(const SCEV *Expr, Region &R, 773 ScalarEvolution &SE, LoopInfo &LI, 774 const DominatorTree &DT) { 775 if (auto *Unknown = dyn_cast<SCEVUnknown>(Expr)) { 776 Value *V = Unknown->getValue(); 777 auto *PHI = dyn_cast<PHINode>(V); 778 if (!PHI) 779 return Expr; 780 781 Value *Final = nullptr; 782 783 for (unsigned i = 0; i < PHI->getNumIncomingValues(); i++) { 784 BasicBlock *Incoming = PHI->getIncomingBlock(i); 785 if (isErrorBlock(*Incoming, R, LI, DT) && R.contains(Incoming)) 786 continue; 787 if (Final) 788 return Expr; 789 Final = PHI->getIncomingValue(i); 790 } 791 792 if (Final) 793 return SE.getSCEV(Final); 794 } 795 return Expr; 796 } 797 798 Value *getUniqueNonErrorValue(PHINode *PHI, Region *R, LoopInfo &LI, 799 const DominatorTree &DT) { 800 Value *V = nullptr; 801 for (unsigned i = 0; i < PHI->getNumIncomingValues(); i++) { 802 BasicBlock *BB = PHI->getIncomingBlock(i); 803 if (!isErrorBlock(*BB, *R, LI, DT)) { 804 if (V) 805 return nullptr; 806 V = PHI->getIncomingValue(i); 807 } 808 } 809 810 return V; 811 } 812 } // namespace polly 813