1 2 #include "polly/Support/SCEVValidator.h" 3 #include "polly/ScopDetection.h" 4 #include "llvm/Analysis/RegionInfo.h" 5 #include "llvm/Analysis/ScalarEvolution.h" 6 #include "llvm/Analysis/ScalarEvolutionExpressions.h" 7 #include "llvm/Support/Debug.h" 8 9 using namespace llvm; 10 using namespace polly; 11 12 #define DEBUG_TYPE "polly-scev-validator" 13 14 namespace SCEVType { 15 /// The type of a SCEV 16 /// 17 /// To check for the validity of a SCEV we assign to each SCEV a type. The 18 /// possible types are INT, PARAM, IV and INVALID. The order of the types is 19 /// important. The subexpressions of SCEV with a type X can only have a type 20 /// that is smaller or equal than X. 21 enum TYPE { 22 // An integer value. 23 INT, 24 25 // An expression that is constant during the execution of the Scop, 26 // but that may depend on parameters unknown at compile time. 27 PARAM, 28 29 // An expression that may change during the execution of the SCoP. 30 IV, 31 32 // An invalid expression. 33 INVALID 34 }; 35 } // namespace SCEVType 36 37 /// The result the validator returns for a SCEV expression. 38 class ValidatorResult { 39 /// The type of the expression 40 SCEVType::TYPE Type; 41 42 /// The set of Parameters in the expression. 43 ParameterSetTy Parameters; 44 45 public: 46 /// The copy constructor 47 ValidatorResult(const ValidatorResult &Source) { 48 Type = Source.Type; 49 Parameters = Source.Parameters; 50 } 51 52 /// Construct a result with a certain type and no parameters. 53 ValidatorResult(SCEVType::TYPE Type) : Type(Type) { 54 assert(Type != SCEVType::PARAM && "Did you forget to pass the parameter"); 55 } 56 57 /// Construct a result with a certain type and a single parameter. 58 ValidatorResult(SCEVType::TYPE Type, const SCEV *Expr) : Type(Type) { 59 Parameters.insert(Expr); 60 } 61 62 /// Get the type of the ValidatorResult. 63 SCEVType::TYPE getType() { return Type; } 64 65 /// Is the analyzed SCEV constant during the execution of the SCoP. 66 bool isConstant() { return Type == SCEVType::INT || Type == SCEVType::PARAM; } 67 68 /// Is the analyzed SCEV valid. 69 bool isValid() { return Type != SCEVType::INVALID; } 70 71 /// Is the analyzed SCEV of Type IV. 72 bool isIV() { return Type == SCEVType::IV; } 73 74 /// Is the analyzed SCEV of Type INT. 75 bool isINT() { return Type == SCEVType::INT; } 76 77 /// Is the analyzed SCEV of Type PARAM. 78 bool isPARAM() { return Type == SCEVType::PARAM; } 79 80 /// Get the parameters of this validator result. 81 const ParameterSetTy &getParameters() { return Parameters; } 82 83 /// Add the parameters of Source to this result. 84 void addParamsFrom(const ValidatorResult &Source) { 85 Parameters.insert(Source.Parameters.begin(), Source.Parameters.end()); 86 } 87 88 /// Merge a result. 89 /// 90 /// This means to merge the parameters and to set the Type to the most 91 /// specific Type that matches both. 92 void merge(const ValidatorResult &ToMerge) { 93 Type = std::max(Type, ToMerge.Type); 94 addParamsFrom(ToMerge); 95 } 96 97 void print(raw_ostream &OS) { 98 switch (Type) { 99 case SCEVType::INT: 100 OS << "SCEVType::INT"; 101 break; 102 case SCEVType::PARAM: 103 OS << "SCEVType::PARAM"; 104 break; 105 case SCEVType::IV: 106 OS << "SCEVType::IV"; 107 break; 108 case SCEVType::INVALID: 109 OS << "SCEVType::INVALID"; 110 break; 111 } 112 } 113 }; 114 115 raw_ostream &operator<<(raw_ostream &OS, class ValidatorResult &VR) { 116 VR.print(OS); 117 return OS; 118 } 119 120 bool polly::isConstCall(llvm::CallInst *Call) { 121 if (Call->mayReadOrWriteMemory()) 122 return false; 123 124 for (auto &Operand : Call->arg_operands()) 125 if (!isa<ConstantInt>(&Operand)) 126 return false; 127 128 return true; 129 } 130 131 /// Check if a SCEV is valid in a SCoP. 132 struct SCEVValidator 133 : public SCEVVisitor<SCEVValidator, class ValidatorResult> { 134 private: 135 const Region *R; 136 Loop *Scope; 137 ScalarEvolution &SE; 138 InvariantLoadsSetTy *ILS; 139 140 public: 141 SCEVValidator(const Region *R, Loop *Scope, ScalarEvolution &SE, 142 InvariantLoadsSetTy *ILS) 143 : R(R), Scope(Scope), SE(SE), ILS(ILS) {} 144 145 class ValidatorResult visitConstant(const SCEVConstant *Constant) { 146 return ValidatorResult(SCEVType::INT); 147 } 148 149 class ValidatorResult visitZeroExtendOrTruncateExpr(const SCEV *Expr, 150 const SCEV *Operand) { 151 ValidatorResult Op = visit(Operand); 152 auto Type = Op.getType(); 153 154 // If unsigned operations are allowed return the operand, otherwise 155 // check if we can model the expression without unsigned assumptions. 156 if (PollyAllowUnsignedOperations || Type == SCEVType::INVALID) 157 return Op; 158 159 if (Type == SCEVType::IV) 160 return ValidatorResult(SCEVType::INVALID); 161 return ValidatorResult(SCEVType::PARAM, Expr); 162 } 163 164 class ValidatorResult visitPtrToIntExpr(const SCEVPtrToIntExpr *Expr) { 165 return visit(Expr->getOperand()); 166 } 167 168 class ValidatorResult visitTruncateExpr(const SCEVTruncateExpr *Expr) { 169 return visitZeroExtendOrTruncateExpr(Expr, Expr->getOperand()); 170 } 171 172 class ValidatorResult visitZeroExtendExpr(const SCEVZeroExtendExpr *Expr) { 173 return visitZeroExtendOrTruncateExpr(Expr, Expr->getOperand()); 174 } 175 176 class ValidatorResult visitSignExtendExpr(const SCEVSignExtendExpr *Expr) { 177 return visit(Expr->getOperand()); 178 } 179 180 class ValidatorResult visitAddExpr(const SCEVAddExpr *Expr) { 181 ValidatorResult Return(SCEVType::INT); 182 183 for (int i = 0, e = Expr->getNumOperands(); i < e; ++i) { 184 ValidatorResult Op = visit(Expr->getOperand(i)); 185 Return.merge(Op); 186 187 // Early exit. 188 if (!Return.isValid()) 189 break; 190 } 191 192 return Return; 193 } 194 195 class ValidatorResult visitMulExpr(const SCEVMulExpr *Expr) { 196 ValidatorResult Return(SCEVType::INT); 197 198 bool HasMultipleParams = false; 199 200 for (int i = 0, e = Expr->getNumOperands(); i < e; ++i) { 201 ValidatorResult Op = visit(Expr->getOperand(i)); 202 203 if (Op.isINT()) 204 continue; 205 206 if (Op.isPARAM() && Return.isPARAM()) { 207 HasMultipleParams = true; 208 continue; 209 } 210 211 if ((Op.isIV() || Op.isPARAM()) && !Return.isINT()) { 212 LLVM_DEBUG( 213 dbgs() << "INVALID: More than one non-int operand in MulExpr\n" 214 << "\tExpr: " << *Expr << "\n" 215 << "\tPrevious expression type: " << Return << "\n" 216 << "\tNext operand (" << Op << "): " << *Expr->getOperand(i) 217 << "\n"); 218 219 return ValidatorResult(SCEVType::INVALID); 220 } 221 222 Return.merge(Op); 223 } 224 225 if (HasMultipleParams && Return.isValid()) 226 return ValidatorResult(SCEVType::PARAM, Expr); 227 228 return Return; 229 } 230 231 class ValidatorResult visitAddRecExpr(const SCEVAddRecExpr *Expr) { 232 if (!Expr->isAffine()) { 233 LLVM_DEBUG(dbgs() << "INVALID: AddRec is not affine"); 234 return ValidatorResult(SCEVType::INVALID); 235 } 236 237 ValidatorResult Start = visit(Expr->getStart()); 238 ValidatorResult Recurrence = visit(Expr->getStepRecurrence(SE)); 239 240 if (!Start.isValid()) 241 return Start; 242 243 if (!Recurrence.isValid()) 244 return Recurrence; 245 246 auto *L = Expr->getLoop(); 247 if (R->contains(L) && (!Scope || !L->contains(Scope))) { 248 LLVM_DEBUG( 249 dbgs() << "INVALID: Loop of AddRec expression boxed in an a " 250 "non-affine subregion or has a non-synthesizable exit " 251 "value."); 252 return ValidatorResult(SCEVType::INVALID); 253 } 254 255 if (R->contains(L)) { 256 if (Recurrence.isINT()) { 257 ValidatorResult Result(SCEVType::IV); 258 Result.addParamsFrom(Start); 259 return Result; 260 } 261 262 LLVM_DEBUG(dbgs() << "INVALID: AddRec within scop has non-int" 263 "recurrence part"); 264 return ValidatorResult(SCEVType::INVALID); 265 } 266 267 assert(Recurrence.isConstant() && "Expected 'Recurrence' to be constant"); 268 269 // Directly generate ValidatorResult for Expr if 'start' is zero. 270 if (Expr->getStart()->isZero()) 271 return ValidatorResult(SCEVType::PARAM, Expr); 272 273 // Translate AddRecExpr from '{start, +, inc}' into 'start + {0, +, inc}' 274 // if 'start' is not zero. 275 const SCEV *ZeroStartExpr = SE.getAddRecExpr( 276 SE.getConstant(Expr->getStart()->getType(), 0), 277 Expr->getStepRecurrence(SE), Expr->getLoop(), Expr->getNoWrapFlags()); 278 279 ValidatorResult ZeroStartResult = 280 ValidatorResult(SCEVType::PARAM, ZeroStartExpr); 281 ZeroStartResult.addParamsFrom(Start); 282 283 return ZeroStartResult; 284 } 285 286 class ValidatorResult visitSMaxExpr(const SCEVSMaxExpr *Expr) { 287 ValidatorResult Return(SCEVType::INT); 288 289 for (int i = 0, e = Expr->getNumOperands(); i < e; ++i) { 290 ValidatorResult Op = visit(Expr->getOperand(i)); 291 292 if (!Op.isValid()) 293 return Op; 294 295 Return.merge(Op); 296 } 297 298 return Return; 299 } 300 301 class ValidatorResult visitSMinExpr(const SCEVSMinExpr *Expr) { 302 ValidatorResult Return(SCEVType::INT); 303 304 for (int i = 0, e = Expr->getNumOperands(); i < e; ++i) { 305 ValidatorResult Op = visit(Expr->getOperand(i)); 306 307 if (!Op.isValid()) 308 return Op; 309 310 Return.merge(Op); 311 } 312 313 return Return; 314 } 315 316 class ValidatorResult visitUMaxExpr(const SCEVUMaxExpr *Expr) { 317 // We do not support unsigned max operations. If 'Expr' is constant during 318 // Scop execution we treat this as a parameter, otherwise we bail out. 319 for (int i = 0, e = Expr->getNumOperands(); i < e; ++i) { 320 ValidatorResult Op = visit(Expr->getOperand(i)); 321 322 if (!Op.isConstant()) { 323 LLVM_DEBUG(dbgs() << "INVALID: UMaxExpr has a non-constant operand"); 324 return ValidatorResult(SCEVType::INVALID); 325 } 326 } 327 328 return ValidatorResult(SCEVType::PARAM, Expr); 329 } 330 331 class ValidatorResult visitUMinExpr(const SCEVUMinExpr *Expr) { 332 // We do not support unsigned min operations. If 'Expr' is constant during 333 // Scop execution we treat this as a parameter, otherwise we bail out. 334 for (int i = 0, e = Expr->getNumOperands(); i < e; ++i) { 335 ValidatorResult Op = visit(Expr->getOperand(i)); 336 337 if (!Op.isConstant()) { 338 LLVM_DEBUG(dbgs() << "INVALID: UMinExpr has a non-constant operand"); 339 return ValidatorResult(SCEVType::INVALID); 340 } 341 } 342 343 return ValidatorResult(SCEVType::PARAM, Expr); 344 } 345 346 ValidatorResult visitGenericInst(Instruction *I, const SCEV *S) { 347 if (R->contains(I)) { 348 LLVM_DEBUG(dbgs() << "INVALID: UnknownExpr references an instruction " 349 "within the region\n"); 350 return ValidatorResult(SCEVType::INVALID); 351 } 352 353 return ValidatorResult(SCEVType::PARAM, S); 354 } 355 356 ValidatorResult visitCallInstruction(Instruction *I, const SCEV *S) { 357 assert(I->getOpcode() == Instruction::Call && "Call instruction expected"); 358 359 if (R->contains(I)) { 360 auto Call = cast<CallInst>(I); 361 362 if (!isConstCall(Call)) 363 return ValidatorResult(SCEVType::INVALID, S); 364 } 365 return ValidatorResult(SCEVType::PARAM, S); 366 } 367 368 ValidatorResult visitLoadInstruction(Instruction *I, const SCEV *S) { 369 if (R->contains(I) && ILS) { 370 ILS->insert(cast<LoadInst>(I)); 371 return ValidatorResult(SCEVType::PARAM, S); 372 } 373 374 return visitGenericInst(I, S); 375 } 376 377 ValidatorResult visitDivision(const SCEV *Dividend, const SCEV *Divisor, 378 const SCEV *DivExpr, 379 Instruction *SDiv = nullptr) { 380 381 // First check if we might be able to model the division, thus if the 382 // divisor is constant. If so, check the dividend, otherwise check if 383 // the whole division can be seen as a parameter. 384 if (isa<SCEVConstant>(Divisor) && !Divisor->isZero()) 385 return visit(Dividend); 386 387 // For signed divisions use the SDiv instruction to check for a parameter 388 // division, for unsigned divisions check the operands. 389 if (SDiv) 390 return visitGenericInst(SDiv, DivExpr); 391 392 ValidatorResult LHS = visit(Dividend); 393 ValidatorResult RHS = visit(Divisor); 394 if (LHS.isConstant() && RHS.isConstant()) 395 return ValidatorResult(SCEVType::PARAM, DivExpr); 396 397 LLVM_DEBUG( 398 dbgs() << "INVALID: unsigned division of non-constant expressions"); 399 return ValidatorResult(SCEVType::INVALID); 400 } 401 402 ValidatorResult visitUDivExpr(const SCEVUDivExpr *Expr) { 403 if (!PollyAllowUnsignedOperations) 404 return ValidatorResult(SCEVType::INVALID); 405 406 auto *Dividend = Expr->getLHS(); 407 auto *Divisor = Expr->getRHS(); 408 return visitDivision(Dividend, Divisor, Expr); 409 } 410 411 ValidatorResult visitSDivInstruction(Instruction *SDiv, const SCEV *Expr) { 412 assert(SDiv->getOpcode() == Instruction::SDiv && 413 "Assumed SDiv instruction!"); 414 415 auto *Dividend = SE.getSCEV(SDiv->getOperand(0)); 416 auto *Divisor = SE.getSCEV(SDiv->getOperand(1)); 417 return visitDivision(Dividend, Divisor, Expr, SDiv); 418 } 419 420 ValidatorResult visitSRemInstruction(Instruction *SRem, const SCEV *S) { 421 assert(SRem->getOpcode() == Instruction::SRem && 422 "Assumed SRem instruction!"); 423 424 auto *Divisor = SRem->getOperand(1); 425 auto *CI = dyn_cast<ConstantInt>(Divisor); 426 if (!CI || CI->isZeroValue()) 427 return visitGenericInst(SRem, S); 428 429 auto *Dividend = SRem->getOperand(0); 430 auto *DividendSCEV = SE.getSCEV(Dividend); 431 return visit(DividendSCEV); 432 } 433 434 ValidatorResult visitUnknown(const SCEVUnknown *Expr) { 435 Value *V = Expr->getValue(); 436 437 if (!Expr->getType()->isIntegerTy() && !Expr->getType()->isPointerTy()) { 438 LLVM_DEBUG(dbgs() << "INVALID: UnknownExpr is not an integer or pointer"); 439 return ValidatorResult(SCEVType::INVALID); 440 } 441 442 if (isa<UndefValue>(V)) { 443 LLVM_DEBUG(dbgs() << "INVALID: UnknownExpr references an undef value"); 444 return ValidatorResult(SCEVType::INVALID); 445 } 446 447 if (Instruction *I = dyn_cast<Instruction>(Expr->getValue())) { 448 switch (I->getOpcode()) { 449 case Instruction::IntToPtr: 450 return visit(SE.getSCEVAtScope(I->getOperand(0), Scope)); 451 case Instruction::Load: 452 return visitLoadInstruction(I, Expr); 453 case Instruction::SDiv: 454 return visitSDivInstruction(I, Expr); 455 case Instruction::SRem: 456 return visitSRemInstruction(I, Expr); 457 case Instruction::Call: 458 return visitCallInstruction(I, Expr); 459 default: 460 return visitGenericInst(I, Expr); 461 } 462 } 463 464 if (Expr->getType()->isPointerTy()) { 465 if (isa<ConstantPointerNull>(V)) 466 return ValidatorResult(SCEVType::INT); // "int" 467 } 468 469 return ValidatorResult(SCEVType::PARAM, Expr); 470 } 471 }; 472 473 class SCEVHasIVParams { 474 bool HasIVParams = false; 475 476 public: 477 SCEVHasIVParams() {} 478 479 bool follow(const SCEV *S) { 480 const SCEVUnknown *Unknown = dyn_cast<SCEVUnknown>(S); 481 if (!Unknown) 482 return true; 483 484 CallInst *Call = dyn_cast<CallInst>(Unknown->getValue()); 485 486 if (!Call) 487 return true; 488 489 if (isConstCall(Call)) { 490 HasIVParams = true; 491 return false; 492 } 493 494 return true; 495 } 496 497 bool isDone() { return HasIVParams; } 498 bool hasIVParams() { return HasIVParams; } 499 }; 500 501 /// Check whether a SCEV refers to an SSA name defined inside a region. 502 class SCEVInRegionDependences { 503 const Region *R; 504 Loop *Scope; 505 const InvariantLoadsSetTy &ILS; 506 bool AllowLoops; 507 bool HasInRegionDeps = false; 508 509 public: 510 SCEVInRegionDependences(const Region *R, Loop *Scope, bool AllowLoops, 511 const InvariantLoadsSetTy &ILS) 512 : R(R), Scope(Scope), ILS(ILS), AllowLoops(AllowLoops) {} 513 514 bool follow(const SCEV *S) { 515 if (auto Unknown = dyn_cast<SCEVUnknown>(S)) { 516 Instruction *Inst = dyn_cast<Instruction>(Unknown->getValue()); 517 518 CallInst *Call = dyn_cast<CallInst>(Unknown->getValue()); 519 520 if (Call && isConstCall(Call)) 521 return false; 522 523 if (Inst) { 524 // When we invariant load hoist a load, we first make sure that there 525 // can be no dependences created by it in the Scop region. So, we should 526 // not consider scalar dependences to `LoadInst`s that are invariant 527 // load hoisted. 528 // 529 // If this check is not present, then we create data dependences which 530 // are strictly not necessary by tracking the invariant load as a 531 // scalar. 532 LoadInst *LI = dyn_cast<LoadInst>(Inst); 533 if (LI && ILS.count(LI) > 0) 534 return false; 535 } 536 537 // Return true when Inst is defined inside the region R. 538 if (!Inst || !R->contains(Inst)) 539 return true; 540 541 HasInRegionDeps = true; 542 return false; 543 } 544 545 if (auto AddRec = dyn_cast<SCEVAddRecExpr>(S)) { 546 if (AllowLoops) 547 return true; 548 549 auto *L = AddRec->getLoop(); 550 if (R->contains(L) && !L->contains(Scope)) { 551 HasInRegionDeps = true; 552 return false; 553 } 554 } 555 556 return true; 557 } 558 bool isDone() { return false; } 559 bool hasDependences() { return HasInRegionDeps; } 560 }; 561 562 namespace polly { 563 /// Find all loops referenced in SCEVAddRecExprs. 564 class SCEVFindLoops { 565 SetVector<const Loop *> &Loops; 566 567 public: 568 SCEVFindLoops(SetVector<const Loop *> &Loops) : Loops(Loops) {} 569 570 bool follow(const SCEV *S) { 571 if (const SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(S)) 572 Loops.insert(AddRec->getLoop()); 573 return true; 574 } 575 bool isDone() { return false; } 576 }; 577 578 void findLoops(const SCEV *Expr, SetVector<const Loop *> &Loops) { 579 SCEVFindLoops FindLoops(Loops); 580 SCEVTraversal<SCEVFindLoops> ST(FindLoops); 581 ST.visitAll(Expr); 582 } 583 584 /// Find all values referenced in SCEVUnknowns. 585 class SCEVFindValues { 586 ScalarEvolution &SE; 587 SetVector<Value *> &Values; 588 589 public: 590 SCEVFindValues(ScalarEvolution &SE, SetVector<Value *> &Values) 591 : SE(SE), Values(Values) {} 592 593 bool follow(const SCEV *S) { 594 const SCEVUnknown *Unknown = dyn_cast<SCEVUnknown>(S); 595 if (!Unknown) 596 return true; 597 598 Values.insert(Unknown->getValue()); 599 Instruction *Inst = dyn_cast<Instruction>(Unknown->getValue()); 600 if (!Inst || (Inst->getOpcode() != Instruction::SRem && 601 Inst->getOpcode() != Instruction::SDiv)) 602 return false; 603 604 auto *Dividend = SE.getSCEV(Inst->getOperand(1)); 605 if (!isa<SCEVConstant>(Dividend)) 606 return false; 607 608 auto *Divisor = SE.getSCEV(Inst->getOperand(0)); 609 SCEVFindValues FindValues(SE, Values); 610 SCEVTraversal<SCEVFindValues> ST(FindValues); 611 ST.visitAll(Dividend); 612 ST.visitAll(Divisor); 613 614 return false; 615 } 616 bool isDone() { return false; } 617 }; 618 619 void findValues(const SCEV *Expr, ScalarEvolution &SE, 620 SetVector<Value *> &Values) { 621 SCEVFindValues FindValues(SE, Values); 622 SCEVTraversal<SCEVFindValues> ST(FindValues); 623 ST.visitAll(Expr); 624 } 625 626 bool hasIVParams(const SCEV *Expr) { 627 SCEVHasIVParams HasIVParams; 628 SCEVTraversal<SCEVHasIVParams> ST(HasIVParams); 629 ST.visitAll(Expr); 630 return HasIVParams.hasIVParams(); 631 } 632 633 bool hasScalarDepsInsideRegion(const SCEV *Expr, const Region *R, 634 llvm::Loop *Scope, bool AllowLoops, 635 const InvariantLoadsSetTy &ILS) { 636 SCEVInRegionDependences InRegionDeps(R, Scope, AllowLoops, ILS); 637 SCEVTraversal<SCEVInRegionDependences> ST(InRegionDeps); 638 ST.visitAll(Expr); 639 return InRegionDeps.hasDependences(); 640 } 641 642 bool isAffineExpr(const Region *R, llvm::Loop *Scope, const SCEV *Expr, 643 ScalarEvolution &SE, InvariantLoadsSetTy *ILS) { 644 if (isa<SCEVCouldNotCompute>(Expr)) 645 return false; 646 647 SCEVValidator Validator(R, Scope, SE, ILS); 648 LLVM_DEBUG({ 649 dbgs() << "\n"; 650 dbgs() << "Expr: " << *Expr << "\n"; 651 dbgs() << "Region: " << R->getNameStr() << "\n"; 652 dbgs() << " -> "; 653 }); 654 655 ValidatorResult Result = Validator.visit(Expr); 656 657 LLVM_DEBUG({ 658 if (Result.isValid()) 659 dbgs() << "VALID\n"; 660 dbgs() << "\n"; 661 }); 662 663 return Result.isValid(); 664 } 665 666 static bool isAffineExpr(Value *V, const Region *R, Loop *Scope, 667 ScalarEvolution &SE, ParameterSetTy &Params) { 668 auto *E = SE.getSCEV(V); 669 if (isa<SCEVCouldNotCompute>(E)) 670 return false; 671 672 SCEVValidator Validator(R, Scope, SE, nullptr); 673 ValidatorResult Result = Validator.visit(E); 674 if (!Result.isValid()) 675 return false; 676 677 auto ResultParams = Result.getParameters(); 678 Params.insert(ResultParams.begin(), ResultParams.end()); 679 680 return true; 681 } 682 683 bool isAffineConstraint(Value *V, const Region *R, llvm::Loop *Scope, 684 ScalarEvolution &SE, ParameterSetTy &Params, 685 bool OrExpr) { 686 if (auto *ICmp = dyn_cast<ICmpInst>(V)) { 687 return isAffineConstraint(ICmp->getOperand(0), R, Scope, SE, Params, 688 true) && 689 isAffineConstraint(ICmp->getOperand(1), R, Scope, SE, Params, true); 690 } else if (auto *BinOp = dyn_cast<BinaryOperator>(V)) { 691 auto Opcode = BinOp->getOpcode(); 692 if (Opcode == Instruction::And || Opcode == Instruction::Or) 693 return isAffineConstraint(BinOp->getOperand(0), R, Scope, SE, Params, 694 false) && 695 isAffineConstraint(BinOp->getOperand(1), R, Scope, SE, Params, 696 false); 697 /* Fall through */ 698 } 699 700 if (!OrExpr) 701 return false; 702 703 return isAffineExpr(V, R, Scope, SE, Params); 704 } 705 706 ParameterSetTy getParamsInAffineExpr(const Region *R, Loop *Scope, 707 const SCEV *Expr, ScalarEvolution &SE) { 708 if (isa<SCEVCouldNotCompute>(Expr)) 709 return ParameterSetTy(); 710 711 InvariantLoadsSetTy ILS; 712 SCEVValidator Validator(R, Scope, SE, &ILS); 713 ValidatorResult Result = Validator.visit(Expr); 714 assert(Result.isValid() && "Requested parameters for an invalid SCEV!"); 715 716 return Result.getParameters(); 717 } 718 719 std::pair<const SCEVConstant *, const SCEV *> 720 extractConstantFactor(const SCEV *S, ScalarEvolution &SE) { 721 auto *ConstPart = cast<SCEVConstant>(SE.getConstant(S->getType(), 1)); 722 723 if (auto *Constant = dyn_cast<SCEVConstant>(S)) 724 return std::make_pair(Constant, SE.getConstant(S->getType(), 1)); 725 726 auto *AddRec = dyn_cast<SCEVAddRecExpr>(S); 727 if (AddRec) { 728 auto *StartExpr = AddRec->getStart(); 729 if (StartExpr->isZero()) { 730 auto StepPair = extractConstantFactor(AddRec->getStepRecurrence(SE), SE); 731 auto *LeftOverAddRec = 732 SE.getAddRecExpr(StartExpr, StepPair.second, AddRec->getLoop(), 733 AddRec->getNoWrapFlags()); 734 return std::make_pair(StepPair.first, LeftOverAddRec); 735 } 736 return std::make_pair(ConstPart, S); 737 } 738 739 if (auto *Add = dyn_cast<SCEVAddExpr>(S)) { 740 SmallVector<const SCEV *, 4> LeftOvers; 741 auto Op0Pair = extractConstantFactor(Add->getOperand(0), SE); 742 auto *Factor = Op0Pair.first; 743 if (SE.isKnownNegative(Factor)) { 744 Factor = cast<SCEVConstant>(SE.getNegativeSCEV(Factor)); 745 LeftOvers.push_back(SE.getNegativeSCEV(Op0Pair.second)); 746 } else { 747 LeftOvers.push_back(Op0Pair.second); 748 } 749 750 for (unsigned u = 1, e = Add->getNumOperands(); u < e; u++) { 751 auto OpUPair = extractConstantFactor(Add->getOperand(u), SE); 752 // TODO: Use something smarter than equality here, e.g., gcd. 753 if (Factor == OpUPair.first) 754 LeftOvers.push_back(OpUPair.second); 755 else if (Factor == SE.getNegativeSCEV(OpUPair.first)) 756 LeftOvers.push_back(SE.getNegativeSCEV(OpUPair.second)); 757 else 758 return std::make_pair(ConstPart, S); 759 } 760 761 auto *NewAdd = SE.getAddExpr(LeftOvers, Add->getNoWrapFlags()); 762 return std::make_pair(Factor, NewAdd); 763 } 764 765 auto *Mul = dyn_cast<SCEVMulExpr>(S); 766 if (!Mul) 767 return std::make_pair(ConstPart, S); 768 769 SmallVector<const SCEV *, 4> LeftOvers; 770 for (auto *Op : Mul->operands()) 771 if (isa<SCEVConstant>(Op)) 772 ConstPart = cast<SCEVConstant>(SE.getMulExpr(ConstPart, Op)); 773 else 774 LeftOvers.push_back(Op); 775 776 return std::make_pair(ConstPart, SE.getMulExpr(LeftOvers)); 777 } 778 779 const SCEV *tryForwardThroughPHI(const SCEV *Expr, Region &R, 780 ScalarEvolution &SE, LoopInfo &LI, 781 const DominatorTree &DT) { 782 if (auto *Unknown = dyn_cast<SCEVUnknown>(Expr)) { 783 Value *V = Unknown->getValue(); 784 auto *PHI = dyn_cast<PHINode>(V); 785 if (!PHI) 786 return Expr; 787 788 Value *Final = nullptr; 789 790 for (unsigned i = 0; i < PHI->getNumIncomingValues(); i++) { 791 BasicBlock *Incoming = PHI->getIncomingBlock(i); 792 if (isErrorBlock(*Incoming, R, LI, DT) && R.contains(Incoming)) 793 continue; 794 if (Final) 795 return Expr; 796 Final = PHI->getIncomingValue(i); 797 } 798 799 if (Final) 800 return SE.getSCEV(Final); 801 } 802 return Expr; 803 } 804 805 Value *getUniqueNonErrorValue(PHINode *PHI, Region *R, LoopInfo &LI, 806 const DominatorTree &DT) { 807 Value *V = nullptr; 808 for (unsigned i = 0; i < PHI->getNumIncomingValues(); i++) { 809 BasicBlock *BB = PHI->getIncomingBlock(i); 810 if (!isErrorBlock(*BB, *R, LI, DT)) { 811 if (V) 812 return nullptr; 813 V = PHI->getIncomingValue(i); 814 } 815 } 816 817 return V; 818 } 819 } // namespace polly 820