1 2 #include "polly/Support/SCEVValidator.h" 3 #include "polly/ScopInfo.h" 4 #include "llvm/Analysis/ScalarEvolution.h" 5 #include "llvm/Analysis/ScalarEvolutionExpressions.h" 6 #include "llvm/Analysis/RegionInfo.h" 7 #include "llvm/Support/Debug.h" 8 9 #include <vector> 10 11 using namespace llvm; 12 13 #define DEBUG_TYPE "polly-scev-validator" 14 15 namespace SCEVType { 16 /// @brief The type of a SCEV 17 /// 18 /// To check for the validity of a SCEV we assign to each SCEV a type. The 19 /// possible types are INT, PARAM, IV and INVALID. The order of the types is 20 /// important. The subexpressions of SCEV with a type X can only have a type 21 /// that is smaller or equal than X. 22 enum TYPE { 23 // An integer value. 24 INT, 25 26 // An expression that is constant during the execution of the Scop, 27 // but that may depend on parameters unknown at compile time. 28 PARAM, 29 30 // An expression that may change during the execution of the SCoP. 31 IV, 32 33 // An invalid expression. 34 INVALID 35 }; 36 } 37 38 /// @brief The result the validator returns for a SCEV expression. 39 class ValidatorResult { 40 /// @brief The type of the expression 41 SCEVType::TYPE Type; 42 43 /// @brief The set of Parameters in the expression. 44 std::vector<const SCEV *> Parameters; 45 46 public: 47 /// @brief The copy constructor 48 ValidatorResult(const ValidatorResult &Source) { 49 Type = Source.Type; 50 Parameters = Source.Parameters; 51 } 52 53 /// @brief Construct a result with a certain type and no parameters. 54 ValidatorResult(SCEVType::TYPE Type) : Type(Type) { 55 assert(Type != SCEVType::PARAM && "Did you forget to pass the parameter"); 56 } 57 58 /// @brief Construct a result with a certain type and a single parameter. 59 ValidatorResult(SCEVType::TYPE Type, const SCEV *Expr) : Type(Type) { 60 Parameters.push_back(Expr); 61 } 62 63 /// @brief Get the type of the ValidatorResult. 64 SCEVType::TYPE getType() { return Type; } 65 66 /// @brief Is the analyzed SCEV constant during the execution of the SCoP. 67 bool isConstant() { return Type == SCEVType::INT || Type == SCEVType::PARAM; } 68 69 /// @brief Is the analyzed SCEV valid. 70 bool isValid() { return Type != SCEVType::INVALID; } 71 72 /// @brief Is the analyzed SCEV of Type IV. 73 bool isIV() { return Type == SCEVType::IV; } 74 75 /// @brief Is the analyzed SCEV of Type INT. 76 bool isINT() { return Type == SCEVType::INT; } 77 78 /// @brief Is the analyzed SCEV of Type PARAM. 79 bool isPARAM() { return Type == SCEVType::PARAM; } 80 81 /// @brief Get the parameters of this validator result. 82 std::vector<const SCEV *> getParameters() { return Parameters; } 83 84 /// @brief Add the parameters of Source to this result. 85 void addParamsFrom(const ValidatorResult &Source) { 86 Parameters.insert(Parameters.end(), Source.Parameters.begin(), 87 Source.Parameters.end()); 88 } 89 90 /// @brief Merge a result. 91 /// 92 /// This means to merge the parameters and to set the Type to the most 93 /// specific Type that matches both. 94 void merge(const ValidatorResult &ToMerge) { 95 Type = std::max(Type, ToMerge.Type); 96 addParamsFrom(ToMerge); 97 } 98 99 void print(raw_ostream &OS) { 100 switch (Type) { 101 case SCEVType::INT: 102 OS << "SCEVType::INT"; 103 break; 104 case SCEVType::PARAM: 105 OS << "SCEVType::PARAM"; 106 break; 107 case SCEVType::IV: 108 OS << "SCEVType::IV"; 109 break; 110 case SCEVType::INVALID: 111 OS << "SCEVType::INVALID"; 112 break; 113 } 114 } 115 }; 116 117 raw_ostream &operator<<(raw_ostream &OS, class ValidatorResult &VR) { 118 VR.print(OS); 119 return OS; 120 } 121 122 /// Check if a SCEV is valid in a SCoP. 123 struct SCEVValidator 124 : public SCEVVisitor<SCEVValidator, class ValidatorResult> { 125 private: 126 const Region *R; 127 ScalarEvolution &SE; 128 const Value *BaseAddress; 129 130 public: 131 SCEVValidator(const Region *R, ScalarEvolution &SE, const Value *BaseAddress) 132 : R(R), SE(SE), BaseAddress(BaseAddress) {} 133 134 class ValidatorResult visitConstant(const SCEVConstant *Constant) { 135 return ValidatorResult(SCEVType::INT); 136 } 137 138 class ValidatorResult visitTruncateExpr(const SCEVTruncateExpr *Expr) { 139 ValidatorResult Op = visit(Expr->getOperand()); 140 141 switch (Op.getType()) { 142 case SCEVType::INT: 143 case SCEVType::PARAM: 144 // We currently do not represent a truncate expression as an affine 145 // expression. If it is constant during Scop execution, we treat it as a 146 // parameter. 147 return ValidatorResult(SCEVType::PARAM, Expr); 148 case SCEVType::IV: 149 DEBUG(dbgs() << "INVALID: Truncation of SCEVType::IV expression"); 150 return ValidatorResult(SCEVType::INVALID); 151 case SCEVType::INVALID: 152 return Op; 153 } 154 155 llvm_unreachable("Unknown SCEVType"); 156 } 157 158 class ValidatorResult visitZeroExtendExpr(const SCEVZeroExtendExpr *Expr) { 159 ValidatorResult Op = visit(Expr->getOperand()); 160 161 switch (Op.getType()) { 162 case SCEVType::INT: 163 case SCEVType::PARAM: 164 // We currently do not represent a truncate expression as an affine 165 // expression. If it is constant during Scop execution, we treat it as a 166 // parameter. 167 return ValidatorResult(SCEVType::PARAM, Expr); 168 case SCEVType::IV: 169 DEBUG(dbgs() << "INVALID: ZeroExtend of SCEVType::IV expression"); 170 return ValidatorResult(SCEVType::INVALID); 171 case SCEVType::INVALID: 172 return Op; 173 } 174 175 llvm_unreachable("Unknown SCEVType"); 176 } 177 178 class ValidatorResult visitSignExtendExpr(const SCEVSignExtendExpr *Expr) { 179 // We currently allow only signed SCEV expressions. In the case of a 180 // signed value, a sign extend is a noop. 181 // 182 // TODO: Reconsider this when we add support for unsigned values. 183 return visit(Expr->getOperand()); 184 } 185 186 class ValidatorResult visitAddExpr(const SCEVAddExpr *Expr) { 187 ValidatorResult Return(SCEVType::INT); 188 189 for (int i = 0, e = Expr->getNumOperands(); i < e; ++i) { 190 ValidatorResult Op = visit(Expr->getOperand(i)); 191 Return.merge(Op); 192 193 // Early exit. 194 if (!Return.isValid()) 195 break; 196 } 197 198 // TODO: Check for NSW and NUW. 199 return Return; 200 } 201 202 class ValidatorResult visitMulExpr(const SCEVMulExpr *Expr) { 203 ValidatorResult Return(SCEVType::INT); 204 205 bool HasMultipleParams = false; 206 207 for (int i = 0, e = Expr->getNumOperands(); i < e; ++i) { 208 ValidatorResult Op = visit(Expr->getOperand(i)); 209 210 if (Op.isINT()) 211 continue; 212 213 if (Op.isPARAM() && Return.isPARAM()) { 214 HasMultipleParams = true; 215 continue; 216 } 217 218 if ((Op.isIV() || Op.isPARAM()) && !Return.isINT()) { 219 DEBUG(dbgs() << "INVALID: More than one non-int operand in MulExpr\n" 220 << "\tExpr: " << *Expr << "\n" 221 << "\tPrevious expression type: " << Return << "\n" 222 << "\tNext operand (" << Op 223 << "): " << *Expr->getOperand(i) << "\n"); 224 225 return ValidatorResult(SCEVType::INVALID); 226 } 227 228 Return.merge(Op); 229 } 230 231 if (HasMultipleParams && Return.isValid()) 232 return ValidatorResult(SCEVType::PARAM, Expr); 233 234 // TODO: Check for NSW and NUW. 235 return Return; 236 } 237 238 class ValidatorResult visitUDivExpr(const SCEVUDivExpr *Expr) { 239 ValidatorResult LHS = visit(Expr->getLHS()); 240 ValidatorResult RHS = visit(Expr->getRHS()); 241 242 // We currently do not represent an unsigned division as an affine 243 // expression. If the division is constant during Scop execution we treat it 244 // as a parameter, otherwise we bail out. 245 if (LHS.isConstant() && RHS.isConstant()) 246 return ValidatorResult(SCEVType::PARAM, Expr); 247 248 DEBUG(dbgs() << "INVALID: unsigned division of non-constant expressions"); 249 return ValidatorResult(SCEVType::INVALID); 250 } 251 252 class ValidatorResult visitAddRecExpr(const SCEVAddRecExpr *Expr) { 253 if (!Expr->isAffine()) { 254 DEBUG(dbgs() << "INVALID: AddRec is not affine"); 255 return ValidatorResult(SCEVType::INVALID); 256 } 257 258 ValidatorResult Start = visit(Expr->getStart()); 259 ValidatorResult Recurrence = visit(Expr->getStepRecurrence(SE)); 260 261 if (!Start.isValid()) 262 return Start; 263 264 if (!Recurrence.isValid()) 265 return Recurrence; 266 267 if (R->contains(Expr->getLoop())) { 268 if (Recurrence.isINT()) { 269 ValidatorResult Result(SCEVType::IV); 270 Result.addParamsFrom(Start); 271 return Result; 272 } 273 274 DEBUG(dbgs() << "INVALID: AddRec within scop has non-int" 275 "recurrence part"); 276 return ValidatorResult(SCEVType::INVALID); 277 } 278 279 assert(Start.isConstant() && Recurrence.isConstant() && 280 "Expected 'Start' and 'Recurrence' to be constant"); 281 282 // Directly generate ValidatorResult for Expr if 'start' is zero. 283 if (Expr->getStart()->isZero()) 284 return ValidatorResult(SCEVType::PARAM, Expr); 285 286 // Translate AddRecExpr from '{start, +, inc}' into 'start + {0, +, inc}' 287 // if 'start' is not zero. 288 const SCEV *ZeroStartExpr = SE.getAddRecExpr( 289 SE.getConstant(Expr->getStart()->getType(), 0), 290 Expr->getStepRecurrence(SE), Expr->getLoop(), Expr->getNoWrapFlags()); 291 292 ValidatorResult ZeroStartResult = 293 ValidatorResult(SCEVType::PARAM, ZeroStartExpr); 294 ZeroStartResult.addParamsFrom(Start); 295 296 return ZeroStartResult; 297 } 298 299 class ValidatorResult visitSMaxExpr(const SCEVSMaxExpr *Expr) { 300 ValidatorResult Return(SCEVType::INT); 301 302 for (int i = 0, e = Expr->getNumOperands(); i < e; ++i) { 303 ValidatorResult Op = visit(Expr->getOperand(i)); 304 305 if (!Op.isValid()) 306 return Op; 307 308 Return.merge(Op); 309 } 310 311 return Return; 312 } 313 314 class ValidatorResult visitUMaxExpr(const SCEVUMaxExpr *Expr) { 315 // We do not support unsigned operations. If 'Expr' is constant during Scop 316 // execution we treat this as a parameter, otherwise we bail out. 317 for (int i = 0, e = Expr->getNumOperands(); i < e; ++i) { 318 ValidatorResult Op = visit(Expr->getOperand(i)); 319 320 if (!Op.isConstant()) { 321 DEBUG(dbgs() << "INVALID: UMaxExpr has a non-constant operand"); 322 return ValidatorResult(SCEVType::INVALID); 323 } 324 } 325 326 return ValidatorResult(SCEVType::PARAM, Expr); 327 } 328 329 ValidatorResult visitGenericInst(Instruction *I, const SCEV *S) { 330 if (R->contains(I)) { 331 DEBUG(dbgs() << "INVALID: UnknownExpr references an instruction " 332 "within the region\n"); 333 return ValidatorResult(SCEVType::INVALID); 334 } 335 336 return ValidatorResult(SCEVType::PARAM, S); 337 } 338 339 ValidatorResult visitSDivInstruction(Instruction *SDiv, const SCEV *S) { 340 assert(SDiv->getOpcode() == Instruction::SDiv && 341 "Assumed SDiv instruction!"); 342 343 auto *Divisor = SDiv->getOperand(1); 344 auto *CI = dyn_cast<ConstantInt>(Divisor); 345 if (!CI) 346 return visitGenericInst(SDiv, S); 347 348 auto *Dividend = SDiv->getOperand(0); 349 auto *DividendSCEV = SE.getSCEV(Dividend); 350 return visit(DividendSCEV); 351 } 352 353 ValidatorResult visitUnknown(const SCEVUnknown *Expr) { 354 Value *V = Expr->getValue(); 355 356 if (!(Expr->getType()->isIntegerTy() || Expr->getType()->isPointerTy())) { 357 DEBUG(dbgs() << "INVALID: UnknownExpr is not an integer or pointer type"); 358 return ValidatorResult(SCEVType::INVALID); 359 } 360 361 if (isa<UndefValue>(V)) { 362 DEBUG(dbgs() << "INVALID: UnknownExpr references an undef value"); 363 return ValidatorResult(SCEVType::INVALID); 364 } 365 366 if (BaseAddress == V) { 367 DEBUG(dbgs() << "INVALID: UnknownExpr references BaseAddress\n"); 368 return ValidatorResult(SCEVType::INVALID); 369 } 370 371 if (Instruction *I = dyn_cast<Instruction>(Expr->getValue())) { 372 switch (I->getOpcode()) { 373 case Instruction::SDiv: 374 return visitSDivInstruction(I, Expr); 375 default: 376 return visitGenericInst(I, Expr); 377 } 378 } 379 380 return ValidatorResult(SCEVType::PARAM, Expr); 381 } 382 }; 383 384 /// @brief Check whether a SCEV refers to an SSA name defined inside a region. 385 /// 386 struct SCEVInRegionDependences 387 : public SCEVVisitor<SCEVInRegionDependences, bool> { 388 public: 389 /// Returns true when the SCEV has SSA names defined in region R. 390 static bool hasDependences(const SCEV *S, const Region *R) { 391 SCEVInRegionDependences Ignore(R); 392 return Ignore.visit(S); 393 } 394 395 SCEVInRegionDependences(const Region *R) : R(R) {} 396 397 bool visit(const SCEV *Expr) { 398 return SCEVVisitor<SCEVInRegionDependences, bool>::visit(Expr); 399 } 400 401 bool visitConstant(const SCEVConstant *Constant) { return false; } 402 403 bool visitTruncateExpr(const SCEVTruncateExpr *Expr) { 404 return visit(Expr->getOperand()); 405 } 406 407 bool visitZeroExtendExpr(const SCEVZeroExtendExpr *Expr) { 408 return visit(Expr->getOperand()); 409 } 410 411 bool visitSignExtendExpr(const SCEVSignExtendExpr *Expr) { 412 return visit(Expr->getOperand()); 413 } 414 415 bool visitAddExpr(const SCEVAddExpr *Expr) { 416 for (int i = 0, e = Expr->getNumOperands(); i < e; ++i) 417 if (visit(Expr->getOperand(i))) 418 return true; 419 420 return false; 421 } 422 423 bool visitMulExpr(const SCEVMulExpr *Expr) { 424 for (int i = 0, e = Expr->getNumOperands(); i < e; ++i) 425 if (visit(Expr->getOperand(i))) 426 return true; 427 428 return false; 429 } 430 431 bool visitUDivExpr(const SCEVUDivExpr *Expr) { 432 if (visit(Expr->getLHS())) 433 return true; 434 435 if (visit(Expr->getRHS())) 436 return true; 437 438 return false; 439 } 440 441 bool visitAddRecExpr(const SCEVAddRecExpr *Expr) { 442 if (visit(Expr->getStart())) 443 return true; 444 445 for (size_t i = 0; i < Expr->getNumOperands(); ++i) 446 if (visit(Expr->getOperand(i))) 447 return true; 448 449 return false; 450 } 451 452 bool visitSMaxExpr(const SCEVSMaxExpr *Expr) { 453 for (size_t i = 0; i < Expr->getNumOperands(); ++i) 454 if (visit(Expr->getOperand(i))) 455 return true; 456 457 return false; 458 } 459 460 bool visitUMaxExpr(const SCEVUMaxExpr *Expr) { 461 for (size_t i = 0; i < Expr->getNumOperands(); ++i) 462 if (visit(Expr->getOperand(i))) 463 return true; 464 465 return false; 466 } 467 468 bool visitUnknown(const SCEVUnknown *Expr) { 469 Instruction *Inst = dyn_cast<Instruction>(Expr->getValue()); 470 471 // Return true when Inst is defined inside the region R. 472 if (Inst && R->contains(Inst)) 473 return true; 474 475 return false; 476 } 477 478 private: 479 const Region *R; 480 }; 481 482 namespace polly { 483 /// Find all loops referenced in SCEVAddRecExprs. 484 class SCEVFindLoops { 485 SetVector<const Loop *> &Loops; 486 487 public: 488 SCEVFindLoops(SetVector<const Loop *> &Loops) : Loops(Loops) {} 489 490 bool follow(const SCEV *S) { 491 if (const SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(S)) 492 Loops.insert(AddRec->getLoop()); 493 return true; 494 } 495 bool isDone() { return false; } 496 }; 497 498 void findLoops(const SCEV *Expr, SetVector<const Loop *> &Loops) { 499 SCEVFindLoops FindLoops(Loops); 500 SCEVTraversal<SCEVFindLoops> ST(FindLoops); 501 ST.visitAll(Expr); 502 } 503 504 /// Find all values referenced in SCEVUnknowns. 505 class SCEVFindValues { 506 SetVector<Value *> &Values; 507 508 public: 509 SCEVFindValues(SetVector<Value *> &Values) : Values(Values) {} 510 511 bool follow(const SCEV *S) { 512 if (const SCEVUnknown *Unknown = dyn_cast<SCEVUnknown>(S)) 513 Values.insert(Unknown->getValue()); 514 return true; 515 } 516 bool isDone() { return false; } 517 }; 518 519 void findValues(const SCEV *Expr, SetVector<Value *> &Values) { 520 SCEVFindValues FindValues(Values); 521 SCEVTraversal<SCEVFindValues> ST(FindValues); 522 ST.visitAll(Expr); 523 } 524 525 bool hasScalarDepsInsideRegion(const SCEV *Expr, const Region *R) { 526 return SCEVInRegionDependences::hasDependences(Expr, R); 527 } 528 529 bool isAffineExpr(const Region *R, const SCEV *Expr, ScalarEvolution &SE, 530 const Value *BaseAddress) { 531 if (isa<SCEVCouldNotCompute>(Expr)) 532 return false; 533 534 SCEVValidator Validator(R, SE, BaseAddress); 535 DEBUG({ 536 dbgs() << "\n"; 537 dbgs() << "Expr: " << *Expr << "\n"; 538 dbgs() << "Region: " << R->getNameStr() << "\n"; 539 dbgs() << " -> "; 540 }); 541 542 ValidatorResult Result = Validator.visit(Expr); 543 544 DEBUG({ 545 if (Result.isValid()) 546 dbgs() << "VALID\n"; 547 dbgs() << "\n"; 548 }); 549 550 return Result.isValid(); 551 } 552 553 std::vector<const SCEV *> getParamsInAffineExpr(const Region *R, 554 const SCEV *Expr, 555 ScalarEvolution &SE, 556 const Value *BaseAddress) { 557 if (isa<SCEVCouldNotCompute>(Expr)) 558 return std::vector<const SCEV *>(); 559 560 SCEVValidator Validator(R, SE, BaseAddress); 561 ValidatorResult Result = Validator.visit(Expr); 562 563 return Result.getParameters(); 564 } 565 566 std::pair<const SCEV *, const SCEV *> 567 extractConstantFactor(const SCEV *S, ScalarEvolution &SE) { 568 569 const SCEV *LeftOver = SE.getConstant(S->getType(), 1); 570 const SCEV *ConstPart = SE.getConstant(S->getType(), 1); 571 572 const SCEVMulExpr *M = dyn_cast<SCEVMulExpr>(S); 573 if (!M) 574 return std::make_pair(ConstPart, S); 575 576 for (const SCEV *Op : M->operands()) 577 if (isa<SCEVConstant>(Op)) 578 ConstPart = SE.getMulExpr(ConstPart, Op); 579 else 580 LeftOver = SE.getMulExpr(LeftOver, Op); 581 582 return std::make_pair(ConstPart, LeftOver); 583 } 584 } 585