1 2 #include "polly/Support/SCEVValidator.h" 3 #include "polly/ScopInfo.h" 4 #include "llvm/Analysis/RegionInfo.h" 5 #include "llvm/Analysis/ScalarEvolution.h" 6 #include "llvm/Analysis/ScalarEvolutionExpressions.h" 7 #include "llvm/Support/Debug.h" 8 #include <vector> 9 10 using namespace llvm; 11 12 #define DEBUG_TYPE "polly-scev-validator" 13 14 namespace SCEVType { 15 /// @brief The type of a SCEV 16 /// 17 /// To check for the validity of a SCEV we assign to each SCEV a type. The 18 /// possible types are INT, PARAM, IV and INVALID. The order of the types is 19 /// important. The subexpressions of SCEV with a type X can only have a type 20 /// that is smaller or equal than X. 21 enum TYPE { 22 // An integer value. 23 INT, 24 25 // An expression that is constant during the execution of the Scop, 26 // but that may depend on parameters unknown at compile time. 27 PARAM, 28 29 // An expression that may change during the execution of the SCoP. 30 IV, 31 32 // An invalid expression. 33 INVALID 34 }; 35 } 36 37 /// @brief The result the validator returns for a SCEV expression. 38 class ValidatorResult { 39 /// @brief The type of the expression 40 SCEVType::TYPE Type; 41 42 /// @brief The set of Parameters in the expression. 43 std::vector<const SCEV *> Parameters; 44 45 public: 46 /// @brief The copy constructor 47 ValidatorResult(const ValidatorResult &Source) { 48 Type = Source.Type; 49 Parameters = Source.Parameters; 50 } 51 52 /// @brief Construct a result with a certain type and no parameters. 53 ValidatorResult(SCEVType::TYPE Type) : Type(Type) { 54 assert(Type != SCEVType::PARAM && "Did you forget to pass the parameter"); 55 } 56 57 /// @brief Construct a result with a certain type and a single parameter. 58 ValidatorResult(SCEVType::TYPE Type, const SCEV *Expr) : Type(Type) { 59 Parameters.push_back(Expr); 60 } 61 62 /// @brief Get the type of the ValidatorResult. 63 SCEVType::TYPE getType() { return Type; } 64 65 /// @brief Is the analyzed SCEV constant during the execution of the SCoP. 66 bool isConstant() { return Type == SCEVType::INT || Type == SCEVType::PARAM; } 67 68 /// @brief Is the analyzed SCEV valid. 69 bool isValid() { return Type != SCEVType::INVALID; } 70 71 /// @brief Is the analyzed SCEV of Type IV. 72 bool isIV() { return Type == SCEVType::IV; } 73 74 /// @brief Is the analyzed SCEV of Type INT. 75 bool isINT() { return Type == SCEVType::INT; } 76 77 /// @brief Is the analyzed SCEV of Type PARAM. 78 bool isPARAM() { return Type == SCEVType::PARAM; } 79 80 /// @brief Get the parameters of this validator result. 81 std::vector<const SCEV *> getParameters() { return Parameters; } 82 83 /// @brief Add the parameters of Source to this result. 84 void addParamsFrom(const ValidatorResult &Source) { 85 Parameters.insert(Parameters.end(), Source.Parameters.begin(), 86 Source.Parameters.end()); 87 } 88 89 /// @brief Merge a result. 90 /// 91 /// This means to merge the parameters and to set the Type to the most 92 /// specific Type that matches both. 93 void merge(const ValidatorResult &ToMerge) { 94 Type = std::max(Type, ToMerge.Type); 95 addParamsFrom(ToMerge); 96 } 97 98 void print(raw_ostream &OS) { 99 switch (Type) { 100 case SCEVType::INT: 101 OS << "SCEVType::INT"; 102 break; 103 case SCEVType::PARAM: 104 OS << "SCEVType::PARAM"; 105 break; 106 case SCEVType::IV: 107 OS << "SCEVType::IV"; 108 break; 109 case SCEVType::INVALID: 110 OS << "SCEVType::INVALID"; 111 break; 112 } 113 } 114 }; 115 116 raw_ostream &operator<<(raw_ostream &OS, class ValidatorResult &VR) { 117 VR.print(OS); 118 return OS; 119 } 120 121 /// Check if a SCEV is valid in a SCoP. 122 struct SCEVValidator 123 : public SCEVVisitor<SCEVValidator, class ValidatorResult> { 124 private: 125 const Region *R; 126 ScalarEvolution &SE; 127 const Value *BaseAddress; 128 129 public: 130 SCEVValidator(const Region *R, ScalarEvolution &SE, const Value *BaseAddress) 131 : R(R), SE(SE), BaseAddress(BaseAddress) {} 132 133 class ValidatorResult visitConstant(const SCEVConstant *Constant) { 134 return ValidatorResult(SCEVType::INT); 135 } 136 137 class ValidatorResult visitTruncateExpr(const SCEVTruncateExpr *Expr) { 138 ValidatorResult Op = visit(Expr->getOperand()); 139 140 switch (Op.getType()) { 141 case SCEVType::INT: 142 case SCEVType::PARAM: 143 // We currently do not represent a truncate expression as an affine 144 // expression. If it is constant during Scop execution, we treat it as a 145 // parameter. 146 return ValidatorResult(SCEVType::PARAM, Expr); 147 case SCEVType::IV: 148 DEBUG(dbgs() << "INVALID: Truncation of SCEVType::IV expression"); 149 return ValidatorResult(SCEVType::INVALID); 150 case SCEVType::INVALID: 151 return Op; 152 } 153 154 llvm_unreachable("Unknown SCEVType"); 155 } 156 157 class ValidatorResult visitZeroExtendExpr(const SCEVZeroExtendExpr *Expr) { 158 ValidatorResult Op = visit(Expr->getOperand()); 159 160 switch (Op.getType()) { 161 case SCEVType::INT: 162 case SCEVType::PARAM: 163 // We currently do not represent a truncate expression as an affine 164 // expression. If it is constant during Scop execution, we treat it as a 165 // parameter. 166 return ValidatorResult(SCEVType::PARAM, Expr); 167 case SCEVType::IV: 168 DEBUG(dbgs() << "INVALID: ZeroExtend of SCEVType::IV expression"); 169 return ValidatorResult(SCEVType::INVALID); 170 case SCEVType::INVALID: 171 return Op; 172 } 173 174 llvm_unreachable("Unknown SCEVType"); 175 } 176 177 class ValidatorResult visitSignExtendExpr(const SCEVSignExtendExpr *Expr) { 178 // We currently allow only signed SCEV expressions. In the case of a 179 // signed value, a sign extend is a noop. 180 // 181 // TODO: Reconsider this when we add support for unsigned values. 182 return visit(Expr->getOperand()); 183 } 184 185 class ValidatorResult visitAddExpr(const SCEVAddExpr *Expr) { 186 ValidatorResult Return(SCEVType::INT); 187 188 for (int i = 0, e = Expr->getNumOperands(); i < e; ++i) { 189 ValidatorResult Op = visit(Expr->getOperand(i)); 190 Return.merge(Op); 191 192 // Early exit. 193 if (!Return.isValid()) 194 break; 195 } 196 197 // TODO: Check for NSW and NUW. 198 return Return; 199 } 200 201 class ValidatorResult visitMulExpr(const SCEVMulExpr *Expr) { 202 ValidatorResult Return(SCEVType::INT); 203 204 bool HasMultipleParams = false; 205 206 for (int i = 0, e = Expr->getNumOperands(); i < e; ++i) { 207 ValidatorResult Op = visit(Expr->getOperand(i)); 208 209 if (Op.isINT()) 210 continue; 211 212 if (Op.isPARAM() && Return.isPARAM()) { 213 HasMultipleParams = true; 214 continue; 215 } 216 217 if ((Op.isIV() || Op.isPARAM()) && !Return.isINT()) { 218 DEBUG(dbgs() << "INVALID: More than one non-int operand in MulExpr\n" 219 << "\tExpr: " << *Expr << "\n" 220 << "\tPrevious expression type: " << Return << "\n" 221 << "\tNext operand (" << Op 222 << "): " << *Expr->getOperand(i) << "\n"); 223 224 return ValidatorResult(SCEVType::INVALID); 225 } 226 227 Return.merge(Op); 228 } 229 230 if (HasMultipleParams && Return.isValid()) 231 return ValidatorResult(SCEVType::PARAM, Expr); 232 233 // TODO: Check for NSW and NUW. 234 return Return; 235 } 236 237 class ValidatorResult visitUDivExpr(const SCEVUDivExpr *Expr) { 238 ValidatorResult LHS = visit(Expr->getLHS()); 239 ValidatorResult RHS = visit(Expr->getRHS()); 240 241 // We currently do not represent an unsigned division as an affine 242 // expression. If the division is constant during Scop execution we treat it 243 // as a parameter, otherwise we bail out. 244 if (LHS.isConstant() && RHS.isConstant()) 245 return ValidatorResult(SCEVType::PARAM, Expr); 246 247 DEBUG(dbgs() << "INVALID: unsigned division of non-constant expressions"); 248 return ValidatorResult(SCEVType::INVALID); 249 } 250 251 class ValidatorResult visitAddRecExpr(const SCEVAddRecExpr *Expr) { 252 if (!Expr->isAffine()) { 253 DEBUG(dbgs() << "INVALID: AddRec is not affine"); 254 return ValidatorResult(SCEVType::INVALID); 255 } 256 257 ValidatorResult Start = visit(Expr->getStart()); 258 ValidatorResult Recurrence = visit(Expr->getStepRecurrence(SE)); 259 260 if (!Start.isValid()) 261 return Start; 262 263 if (!Recurrence.isValid()) 264 return Recurrence; 265 266 if (R->contains(Expr->getLoop())) { 267 if (Recurrence.isINT()) { 268 ValidatorResult Result(SCEVType::IV); 269 Result.addParamsFrom(Start); 270 return Result; 271 } 272 273 DEBUG(dbgs() << "INVALID: AddRec within scop has non-int" 274 "recurrence part"); 275 return ValidatorResult(SCEVType::INVALID); 276 } 277 278 assert(Start.isConstant() && Recurrence.isConstant() && 279 "Expected 'Start' and 'Recurrence' to be constant"); 280 281 // Directly generate ValidatorResult for Expr if 'start' is zero. 282 if (Expr->getStart()->isZero()) 283 return ValidatorResult(SCEVType::PARAM, Expr); 284 285 // Translate AddRecExpr from '{start, +, inc}' into 'start + {0, +, inc}' 286 // if 'start' is not zero. 287 const SCEV *ZeroStartExpr = SE.getAddRecExpr( 288 SE.getConstant(Expr->getStart()->getType(), 0), 289 Expr->getStepRecurrence(SE), Expr->getLoop(), Expr->getNoWrapFlags()); 290 291 ValidatorResult ZeroStartResult = 292 ValidatorResult(SCEVType::PARAM, ZeroStartExpr); 293 ZeroStartResult.addParamsFrom(Start); 294 295 return ZeroStartResult; 296 } 297 298 class ValidatorResult visitSMaxExpr(const SCEVSMaxExpr *Expr) { 299 ValidatorResult Return(SCEVType::INT); 300 301 for (int i = 0, e = Expr->getNumOperands(); i < e; ++i) { 302 ValidatorResult Op = visit(Expr->getOperand(i)); 303 304 if (!Op.isValid()) 305 return Op; 306 307 Return.merge(Op); 308 } 309 310 return Return; 311 } 312 313 class ValidatorResult visitUMaxExpr(const SCEVUMaxExpr *Expr) { 314 // We do not support unsigned operations. If 'Expr' is constant during Scop 315 // execution we treat this as a parameter, otherwise we bail out. 316 for (int i = 0, e = Expr->getNumOperands(); i < e; ++i) { 317 ValidatorResult Op = visit(Expr->getOperand(i)); 318 319 if (!Op.isConstant()) { 320 DEBUG(dbgs() << "INVALID: UMaxExpr has a non-constant operand"); 321 return ValidatorResult(SCEVType::INVALID); 322 } 323 } 324 325 return ValidatorResult(SCEVType::PARAM, Expr); 326 } 327 328 ValidatorResult visitGenericInst(Instruction *I, const SCEV *S) { 329 if (R->contains(I)) { 330 DEBUG(dbgs() << "INVALID: UnknownExpr references an instruction " 331 "within the region\n"); 332 return ValidatorResult(SCEVType::INVALID); 333 } 334 335 return ValidatorResult(SCEVType::PARAM, S); 336 } 337 338 ValidatorResult visitSDivInstruction(Instruction *SDiv, const SCEV *S) { 339 assert(SDiv->getOpcode() == Instruction::SDiv && 340 "Assumed SDiv instruction!"); 341 342 auto *Divisor = SDiv->getOperand(1); 343 auto *CI = dyn_cast<ConstantInt>(Divisor); 344 if (!CI) 345 return visitGenericInst(SDiv, S); 346 347 auto *Dividend = SDiv->getOperand(0); 348 auto *DividendSCEV = SE.getSCEV(Dividend); 349 return visit(DividendSCEV); 350 } 351 352 ValidatorResult visitSRemInstruction(Instruction *SRem, const SCEV *S) { 353 assert(SRem->getOpcode() == Instruction::SRem && 354 "Assumed SRem instruction!"); 355 356 auto *Divisor = SRem->getOperand(1); 357 auto *CI = dyn_cast<ConstantInt>(Divisor); 358 if (!CI) 359 return visitGenericInst(SRem, S); 360 361 auto *Dividend = SRem->getOperand(0); 362 auto *DividendSCEV = SE.getSCEV(Dividend); 363 return visit(DividendSCEV); 364 } 365 366 ValidatorResult visitUnknown(const SCEVUnknown *Expr) { 367 Value *V = Expr->getValue(); 368 369 // TODO: FIXME: IslExprBuilder is not capable of producing valid code 370 // for arbitrary pointer expressions at the moment. Until 371 // this is fixed we disallow pointer expressions completely. 372 if (Expr->getType()->isPointerTy()) { 373 DEBUG(dbgs() << "INVALID: UnknownExpr is a pointer type [FIXME]"); 374 return ValidatorResult(SCEVType::INVALID); 375 } 376 377 if (!Expr->getType()->isIntegerTy()) { 378 DEBUG(dbgs() << "INVALID: UnknownExpr is not an integer"); 379 return ValidatorResult(SCEVType::INVALID); 380 } 381 382 if (isa<UndefValue>(V)) { 383 DEBUG(dbgs() << "INVALID: UnknownExpr references an undef value"); 384 return ValidatorResult(SCEVType::INVALID); 385 } 386 387 if (BaseAddress == V) { 388 DEBUG(dbgs() << "INVALID: UnknownExpr references BaseAddress\n"); 389 return ValidatorResult(SCEVType::INVALID); 390 } 391 392 if (Instruction *I = dyn_cast<Instruction>(Expr->getValue())) { 393 switch (I->getOpcode()) { 394 case Instruction::SDiv: 395 return visitSDivInstruction(I, Expr); 396 case Instruction::SRem: 397 return visitSRemInstruction(I, Expr); 398 default: 399 return visitGenericInst(I, Expr); 400 } 401 } 402 403 return ValidatorResult(SCEVType::PARAM, Expr); 404 } 405 }; 406 407 /// @brief Check whether a SCEV refers to an SSA name defined inside a region. 408 /// 409 struct SCEVInRegionDependences 410 : public SCEVVisitor<SCEVInRegionDependences, bool> { 411 public: 412 /// Returns true when the SCEV has SSA names defined in region R. 413 static bool hasDependences(const SCEV *S, const Region *R) { 414 SCEVInRegionDependences Ignore(R); 415 return Ignore.visit(S); 416 } 417 418 SCEVInRegionDependences(const Region *R) : R(R) {} 419 420 bool visit(const SCEV *Expr) { 421 return SCEVVisitor<SCEVInRegionDependences, bool>::visit(Expr); 422 } 423 424 bool visitConstant(const SCEVConstant *Constant) { return false; } 425 426 bool visitTruncateExpr(const SCEVTruncateExpr *Expr) { 427 return visit(Expr->getOperand()); 428 } 429 430 bool visitZeroExtendExpr(const SCEVZeroExtendExpr *Expr) { 431 return visit(Expr->getOperand()); 432 } 433 434 bool visitSignExtendExpr(const SCEVSignExtendExpr *Expr) { 435 return visit(Expr->getOperand()); 436 } 437 438 bool visitAddExpr(const SCEVAddExpr *Expr) { 439 for (int i = 0, e = Expr->getNumOperands(); i < e; ++i) 440 if (visit(Expr->getOperand(i))) 441 return true; 442 443 return false; 444 } 445 446 bool visitMulExpr(const SCEVMulExpr *Expr) { 447 for (int i = 0, e = Expr->getNumOperands(); i < e; ++i) 448 if (visit(Expr->getOperand(i))) 449 return true; 450 451 return false; 452 } 453 454 bool visitUDivExpr(const SCEVUDivExpr *Expr) { 455 if (visit(Expr->getLHS())) 456 return true; 457 458 if (visit(Expr->getRHS())) 459 return true; 460 461 return false; 462 } 463 464 bool visitAddRecExpr(const SCEVAddRecExpr *Expr) { 465 if (visit(Expr->getStart())) 466 return true; 467 468 for (size_t i = 0; i < Expr->getNumOperands(); ++i) 469 if (visit(Expr->getOperand(i))) 470 return true; 471 472 return false; 473 } 474 475 bool visitSMaxExpr(const SCEVSMaxExpr *Expr) { 476 for (size_t i = 0; i < Expr->getNumOperands(); ++i) 477 if (visit(Expr->getOperand(i))) 478 return true; 479 480 return false; 481 } 482 483 bool visitUMaxExpr(const SCEVUMaxExpr *Expr) { 484 for (size_t i = 0; i < Expr->getNumOperands(); ++i) 485 if (visit(Expr->getOperand(i))) 486 return true; 487 488 return false; 489 } 490 491 bool visitUnknown(const SCEVUnknown *Expr) { 492 Instruction *Inst = dyn_cast<Instruction>(Expr->getValue()); 493 494 // Return true when Inst is defined inside the region R. 495 if (Inst && R->contains(Inst)) 496 return true; 497 498 return false; 499 } 500 501 private: 502 const Region *R; 503 }; 504 505 namespace polly { 506 /// Find all loops referenced in SCEVAddRecExprs. 507 class SCEVFindLoops { 508 SetVector<const Loop *> &Loops; 509 510 public: 511 SCEVFindLoops(SetVector<const Loop *> &Loops) : Loops(Loops) {} 512 513 bool follow(const SCEV *S) { 514 if (const SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(S)) 515 Loops.insert(AddRec->getLoop()); 516 return true; 517 } 518 bool isDone() { return false; } 519 }; 520 521 void findLoops(const SCEV *Expr, SetVector<const Loop *> &Loops) { 522 SCEVFindLoops FindLoops(Loops); 523 SCEVTraversal<SCEVFindLoops> ST(FindLoops); 524 ST.visitAll(Expr); 525 } 526 527 /// Find all values referenced in SCEVUnknowns. 528 class SCEVFindValues { 529 SetVector<Value *> &Values; 530 531 public: 532 SCEVFindValues(SetVector<Value *> &Values) : Values(Values) {} 533 534 bool follow(const SCEV *S) { 535 if (const SCEVUnknown *Unknown = dyn_cast<SCEVUnknown>(S)) 536 Values.insert(Unknown->getValue()); 537 return true; 538 } 539 bool isDone() { return false; } 540 }; 541 542 void findValues(const SCEV *Expr, SetVector<Value *> &Values) { 543 SCEVFindValues FindValues(Values); 544 SCEVTraversal<SCEVFindValues> ST(FindValues); 545 ST.visitAll(Expr); 546 } 547 548 bool hasScalarDepsInsideRegion(const SCEV *Expr, const Region *R) { 549 return SCEVInRegionDependences::hasDependences(Expr, R); 550 } 551 552 bool isAffineExpr(const Region *R, const SCEV *Expr, ScalarEvolution &SE, 553 const Value *BaseAddress) { 554 if (isa<SCEVCouldNotCompute>(Expr)) 555 return false; 556 557 SCEVValidator Validator(R, SE, BaseAddress); 558 DEBUG({ 559 dbgs() << "\n"; 560 dbgs() << "Expr: " << *Expr << "\n"; 561 dbgs() << "Region: " << R->getNameStr() << "\n"; 562 dbgs() << " -> "; 563 }); 564 565 ValidatorResult Result = Validator.visit(Expr); 566 567 DEBUG({ 568 if (Result.isValid()) 569 dbgs() << "VALID\n"; 570 dbgs() << "\n"; 571 }); 572 573 return Result.isValid(); 574 } 575 576 std::vector<const SCEV *> getParamsInAffineExpr(const Region *R, 577 const SCEV *Expr, 578 ScalarEvolution &SE, 579 const Value *BaseAddress) { 580 if (isa<SCEVCouldNotCompute>(Expr)) 581 return std::vector<const SCEV *>(); 582 583 SCEVValidator Validator(R, SE, BaseAddress); 584 ValidatorResult Result = Validator.visit(Expr); 585 assert(Result.isValid() && "Requested parameters for an invalid SCEV!"); 586 587 return Result.getParameters(); 588 } 589 590 std::pair<const SCEV *, const SCEV *> 591 extractConstantFactor(const SCEV *S, ScalarEvolution &SE) { 592 593 const SCEV *LeftOver = SE.getConstant(S->getType(), 1); 594 const SCEV *ConstPart = SE.getConstant(S->getType(), 1); 595 596 const SCEVMulExpr *M = dyn_cast<SCEVMulExpr>(S); 597 if (!M) 598 return std::make_pair(ConstPart, S); 599 600 for (const SCEV *Op : M->operands()) 601 if (isa<SCEVConstant>(Op)) 602 ConstPart = SE.getMulExpr(ConstPart, Op); 603 else 604 LeftOver = SE.getMulExpr(LeftOver, Op); 605 606 return std::make_pair(ConstPart, LeftOver); 607 } 608 } 609