1 2 #include "polly/Support/SCEVValidator.h" 3 #include "polly/ScopInfo.h" 4 #include "llvm/Analysis/RegionInfo.h" 5 #include "llvm/Analysis/ScalarEvolution.h" 6 #include "llvm/Analysis/ScalarEvolutionExpressions.h" 7 #include "llvm/Support/Debug.h" 8 #include <vector> 9 10 using namespace llvm; 11 12 #define DEBUG_TYPE "polly-scev-validator" 13 14 namespace SCEVType { 15 /// @brief The type of a SCEV 16 /// 17 /// To check for the validity of a SCEV we assign to each SCEV a type. The 18 /// possible types are INT, PARAM, IV and INVALID. The order of the types is 19 /// important. The subexpressions of SCEV with a type X can only have a type 20 /// that is smaller or equal than X. 21 enum TYPE { 22 // An integer value. 23 INT, 24 25 // An expression that is constant during the execution of the Scop, 26 // but that may depend on parameters unknown at compile time. 27 PARAM, 28 29 // An expression that may change during the execution of the SCoP. 30 IV, 31 32 // An invalid expression. 33 INVALID 34 }; 35 } 36 37 /// @brief The result the validator returns for a SCEV expression. 38 class ValidatorResult { 39 /// @brief The type of the expression 40 SCEVType::TYPE Type; 41 42 /// @brief The set of Parameters in the expression. 43 std::vector<const SCEV *> Parameters; 44 45 public: 46 /// @brief The copy constructor 47 ValidatorResult(const ValidatorResult &Source) { 48 Type = Source.Type; 49 Parameters = Source.Parameters; 50 } 51 52 /// @brief Construct a result with a certain type and no parameters. 53 ValidatorResult(SCEVType::TYPE Type) : Type(Type) { 54 assert(Type != SCEVType::PARAM && "Did you forget to pass the parameter"); 55 } 56 57 /// @brief Construct a result with a certain type and a single parameter. 58 ValidatorResult(SCEVType::TYPE Type, const SCEV *Expr) : Type(Type) { 59 Parameters.push_back(Expr); 60 } 61 62 /// @brief Get the type of the ValidatorResult. 63 SCEVType::TYPE getType() { return Type; } 64 65 /// @brief Is the analyzed SCEV constant during the execution of the SCoP. 66 bool isConstant() { return Type == SCEVType::INT || Type == SCEVType::PARAM; } 67 68 /// @brief Is the analyzed SCEV valid. 69 bool isValid() { return Type != SCEVType::INVALID; } 70 71 /// @brief Is the analyzed SCEV of Type IV. 72 bool isIV() { return Type == SCEVType::IV; } 73 74 /// @brief Is the analyzed SCEV of Type INT. 75 bool isINT() { return Type == SCEVType::INT; } 76 77 /// @brief Is the analyzed SCEV of Type PARAM. 78 bool isPARAM() { return Type == SCEVType::PARAM; } 79 80 /// @brief Get the parameters of this validator result. 81 std::vector<const SCEV *> getParameters() { return Parameters; } 82 83 /// @brief Add the parameters of Source to this result. 84 void addParamsFrom(const ValidatorResult &Source) { 85 Parameters.insert(Parameters.end(), Source.Parameters.begin(), 86 Source.Parameters.end()); 87 } 88 89 /// @brief Merge a result. 90 /// 91 /// This means to merge the parameters and to set the Type to the most 92 /// specific Type that matches both. 93 void merge(const ValidatorResult &ToMerge) { 94 Type = std::max(Type, ToMerge.Type); 95 addParamsFrom(ToMerge); 96 } 97 98 void print(raw_ostream &OS) { 99 switch (Type) { 100 case SCEVType::INT: 101 OS << "SCEVType::INT"; 102 break; 103 case SCEVType::PARAM: 104 OS << "SCEVType::PARAM"; 105 break; 106 case SCEVType::IV: 107 OS << "SCEVType::IV"; 108 break; 109 case SCEVType::INVALID: 110 OS << "SCEVType::INVALID"; 111 break; 112 } 113 } 114 }; 115 116 raw_ostream &operator<<(raw_ostream &OS, class ValidatorResult &VR) { 117 VR.print(OS); 118 return OS; 119 } 120 121 /// Check if a SCEV is valid in a SCoP. 122 struct SCEVValidator 123 : public SCEVVisitor<SCEVValidator, class ValidatorResult> { 124 private: 125 const Region *R; 126 ScalarEvolution &SE; 127 const Value *BaseAddress; 128 129 public: 130 SCEVValidator(const Region *R, ScalarEvolution &SE, const Value *BaseAddress) 131 : R(R), SE(SE), BaseAddress(BaseAddress) {} 132 133 class ValidatorResult visitConstant(const SCEVConstant *Constant) { 134 return ValidatorResult(SCEVType::INT); 135 } 136 137 class ValidatorResult visitTruncateExpr(const SCEVTruncateExpr *Expr) { 138 ValidatorResult Op = visit(Expr->getOperand()); 139 140 switch (Op.getType()) { 141 case SCEVType::INT: 142 case SCEVType::PARAM: 143 // We currently do not represent a truncate expression as an affine 144 // expression. If it is constant during Scop execution, we treat it as a 145 // parameter. 146 return ValidatorResult(SCEVType::PARAM, Expr); 147 case SCEVType::IV: 148 DEBUG(dbgs() << "INVALID: Truncation of SCEVType::IV expression"); 149 return ValidatorResult(SCEVType::INVALID); 150 case SCEVType::INVALID: 151 return Op; 152 } 153 154 llvm_unreachable("Unknown SCEVType"); 155 } 156 157 class ValidatorResult visitZeroExtendExpr(const SCEVZeroExtendExpr *Expr) { 158 ValidatorResult Op = visit(Expr->getOperand()); 159 160 switch (Op.getType()) { 161 case SCEVType::INT: 162 case SCEVType::PARAM: 163 // We currently do not represent a truncate expression as an affine 164 // expression. If it is constant during Scop execution, we treat it as a 165 // parameter. 166 return ValidatorResult(SCEVType::PARAM, Expr); 167 case SCEVType::IV: 168 DEBUG(dbgs() << "INVALID: ZeroExtend of SCEVType::IV expression"); 169 return ValidatorResult(SCEVType::INVALID); 170 case SCEVType::INVALID: 171 return Op; 172 } 173 174 llvm_unreachable("Unknown SCEVType"); 175 } 176 177 class ValidatorResult visitSignExtendExpr(const SCEVSignExtendExpr *Expr) { 178 // We currently allow only signed SCEV expressions. In the case of a 179 // signed value, a sign extend is a noop. 180 // 181 // TODO: Reconsider this when we add support for unsigned values. 182 return visit(Expr->getOperand()); 183 } 184 185 class ValidatorResult visitAddExpr(const SCEVAddExpr *Expr) { 186 ValidatorResult Return(SCEVType::INT); 187 188 for (int i = 0, e = Expr->getNumOperands(); i < e; ++i) { 189 ValidatorResult Op = visit(Expr->getOperand(i)); 190 Return.merge(Op); 191 192 // Early exit. 193 if (!Return.isValid()) 194 break; 195 } 196 197 // TODO: Check for NSW and NUW. 198 return Return; 199 } 200 201 class ValidatorResult visitMulExpr(const SCEVMulExpr *Expr) { 202 ValidatorResult Return(SCEVType::INT); 203 204 bool HasMultipleParams = false; 205 206 for (int i = 0, e = Expr->getNumOperands(); i < e; ++i) { 207 ValidatorResult Op = visit(Expr->getOperand(i)); 208 209 if (Op.isINT()) 210 continue; 211 212 if (Op.isPARAM() && Return.isPARAM()) { 213 HasMultipleParams = true; 214 continue; 215 } 216 217 if ((Op.isIV() || Op.isPARAM()) && !Return.isINT()) { 218 DEBUG(dbgs() << "INVALID: More than one non-int operand in MulExpr\n" 219 << "\tExpr: " << *Expr << "\n" 220 << "\tPrevious expression type: " << Return << "\n" 221 << "\tNext operand (" << Op 222 << "): " << *Expr->getOperand(i) << "\n"); 223 224 return ValidatorResult(SCEVType::INVALID); 225 } 226 227 Return.merge(Op); 228 } 229 230 if (HasMultipleParams && Return.isValid()) 231 return ValidatorResult(SCEVType::PARAM, Expr); 232 233 // TODO: Check for NSW and NUW. 234 return Return; 235 } 236 237 class ValidatorResult visitUDivExpr(const SCEVUDivExpr *Expr) { 238 ValidatorResult LHS = visit(Expr->getLHS()); 239 ValidatorResult RHS = visit(Expr->getRHS()); 240 241 // We currently do not represent an unsigned division as an affine 242 // expression. If the division is constant during Scop execution we treat it 243 // as a parameter, otherwise we bail out. 244 if (LHS.isConstant() && RHS.isConstant()) 245 return ValidatorResult(SCEVType::PARAM, Expr); 246 247 DEBUG(dbgs() << "INVALID: unsigned division of non-constant expressions"); 248 return ValidatorResult(SCEVType::INVALID); 249 } 250 251 class ValidatorResult visitAddRecExpr(const SCEVAddRecExpr *Expr) { 252 if (!Expr->isAffine()) { 253 DEBUG(dbgs() << "INVALID: AddRec is not affine"); 254 return ValidatorResult(SCEVType::INVALID); 255 } 256 257 ValidatorResult Start = visit(Expr->getStart()); 258 ValidatorResult Recurrence = visit(Expr->getStepRecurrence(SE)); 259 260 if (!Start.isValid()) 261 return Start; 262 263 if (!Recurrence.isValid()) 264 return Recurrence; 265 266 if (R->contains(Expr->getLoop())) { 267 if (Recurrence.isINT()) { 268 ValidatorResult Result(SCEVType::IV); 269 Result.addParamsFrom(Start); 270 return Result; 271 } 272 273 DEBUG(dbgs() << "INVALID: AddRec within scop has non-int" 274 "recurrence part"); 275 return ValidatorResult(SCEVType::INVALID); 276 } 277 278 assert(Start.isConstant() && Recurrence.isConstant() && 279 "Expected 'Start' and 'Recurrence' to be constant"); 280 281 // Directly generate ValidatorResult for Expr if 'start' is zero. 282 if (Expr->getStart()->isZero()) 283 return ValidatorResult(SCEVType::PARAM, Expr); 284 285 // Translate AddRecExpr from '{start, +, inc}' into 'start + {0, +, inc}' 286 // if 'start' is not zero. 287 const SCEV *ZeroStartExpr = SE.getAddRecExpr( 288 SE.getConstant(Expr->getStart()->getType(), 0), 289 Expr->getStepRecurrence(SE), Expr->getLoop(), Expr->getNoWrapFlags()); 290 291 ValidatorResult ZeroStartResult = 292 ValidatorResult(SCEVType::PARAM, ZeroStartExpr); 293 ZeroStartResult.addParamsFrom(Start); 294 295 return ZeroStartResult; 296 } 297 298 class ValidatorResult visitSMaxExpr(const SCEVSMaxExpr *Expr) { 299 ValidatorResult Return(SCEVType::INT); 300 301 for (int i = 0, e = Expr->getNumOperands(); i < e; ++i) { 302 ValidatorResult Op = visit(Expr->getOperand(i)); 303 304 if (!Op.isValid()) 305 return Op; 306 307 Return.merge(Op); 308 } 309 310 return Return; 311 } 312 313 class ValidatorResult visitUMaxExpr(const SCEVUMaxExpr *Expr) { 314 // We do not support unsigned operations. If 'Expr' is constant during Scop 315 // execution we treat this as a parameter, otherwise we bail out. 316 for (int i = 0, e = Expr->getNumOperands(); i < e; ++i) { 317 ValidatorResult Op = visit(Expr->getOperand(i)); 318 319 if (!Op.isConstant()) { 320 DEBUG(dbgs() << "INVALID: UMaxExpr has a non-constant operand"); 321 return ValidatorResult(SCEVType::INVALID); 322 } 323 } 324 325 return ValidatorResult(SCEVType::PARAM, Expr); 326 } 327 328 ValidatorResult visitGenericInst(Instruction *I, const SCEV *S) { 329 if (R->contains(I)) { 330 DEBUG(dbgs() << "INVALID: UnknownExpr references an instruction " 331 "within the region\n"); 332 return ValidatorResult(SCEVType::INVALID); 333 } 334 335 return ValidatorResult(SCEVType::PARAM, S); 336 } 337 338 ValidatorResult visitSDivInstruction(Instruction *SDiv, const SCEV *S) { 339 assert(SDiv->getOpcode() == Instruction::SDiv && 340 "Assumed SDiv instruction!"); 341 342 auto *Divisor = SDiv->getOperand(1); 343 auto *CI = dyn_cast<ConstantInt>(Divisor); 344 if (!CI) 345 return visitGenericInst(SDiv, S); 346 347 auto *Dividend = SDiv->getOperand(0); 348 auto *DividendSCEV = SE.getSCEV(Dividend); 349 return visit(DividendSCEV); 350 } 351 352 ValidatorResult visitUnknown(const SCEVUnknown *Expr) { 353 Value *V = Expr->getValue(); 354 355 if (!(Expr->getType()->isIntegerTy() || Expr->getType()->isPointerTy())) { 356 DEBUG(dbgs() << "INVALID: UnknownExpr is not an integer or pointer type"); 357 return ValidatorResult(SCEVType::INVALID); 358 } 359 360 if (isa<UndefValue>(V)) { 361 DEBUG(dbgs() << "INVALID: UnknownExpr references an undef value"); 362 return ValidatorResult(SCEVType::INVALID); 363 } 364 365 if (BaseAddress == V) { 366 DEBUG(dbgs() << "INVALID: UnknownExpr references BaseAddress\n"); 367 return ValidatorResult(SCEVType::INVALID); 368 } 369 370 if (Instruction *I = dyn_cast<Instruction>(Expr->getValue())) { 371 switch (I->getOpcode()) { 372 case Instruction::SDiv: 373 return visitSDivInstruction(I, Expr); 374 default: 375 return visitGenericInst(I, Expr); 376 } 377 } 378 379 return ValidatorResult(SCEVType::PARAM, Expr); 380 } 381 }; 382 383 /// @brief Check whether a SCEV refers to an SSA name defined inside a region. 384 /// 385 struct SCEVInRegionDependences 386 : public SCEVVisitor<SCEVInRegionDependences, bool> { 387 public: 388 /// Returns true when the SCEV has SSA names defined in region R. 389 static bool hasDependences(const SCEV *S, const Region *R) { 390 SCEVInRegionDependences Ignore(R); 391 return Ignore.visit(S); 392 } 393 394 SCEVInRegionDependences(const Region *R) : R(R) {} 395 396 bool visit(const SCEV *Expr) { 397 return SCEVVisitor<SCEVInRegionDependences, bool>::visit(Expr); 398 } 399 400 bool visitConstant(const SCEVConstant *Constant) { return false; } 401 402 bool visitTruncateExpr(const SCEVTruncateExpr *Expr) { 403 return visit(Expr->getOperand()); 404 } 405 406 bool visitZeroExtendExpr(const SCEVZeroExtendExpr *Expr) { 407 return visit(Expr->getOperand()); 408 } 409 410 bool visitSignExtendExpr(const SCEVSignExtendExpr *Expr) { 411 return visit(Expr->getOperand()); 412 } 413 414 bool visitAddExpr(const SCEVAddExpr *Expr) { 415 for (int i = 0, e = Expr->getNumOperands(); i < e; ++i) 416 if (visit(Expr->getOperand(i))) 417 return true; 418 419 return false; 420 } 421 422 bool visitMulExpr(const SCEVMulExpr *Expr) { 423 for (int i = 0, e = Expr->getNumOperands(); i < e; ++i) 424 if (visit(Expr->getOperand(i))) 425 return true; 426 427 return false; 428 } 429 430 bool visitUDivExpr(const SCEVUDivExpr *Expr) { 431 if (visit(Expr->getLHS())) 432 return true; 433 434 if (visit(Expr->getRHS())) 435 return true; 436 437 return false; 438 } 439 440 bool visitAddRecExpr(const SCEVAddRecExpr *Expr) { 441 if (visit(Expr->getStart())) 442 return true; 443 444 for (size_t i = 0; i < Expr->getNumOperands(); ++i) 445 if (visit(Expr->getOperand(i))) 446 return true; 447 448 return false; 449 } 450 451 bool visitSMaxExpr(const SCEVSMaxExpr *Expr) { 452 for (size_t i = 0; i < Expr->getNumOperands(); ++i) 453 if (visit(Expr->getOperand(i))) 454 return true; 455 456 return false; 457 } 458 459 bool visitUMaxExpr(const SCEVUMaxExpr *Expr) { 460 for (size_t i = 0; i < Expr->getNumOperands(); ++i) 461 if (visit(Expr->getOperand(i))) 462 return true; 463 464 return false; 465 } 466 467 bool visitUnknown(const SCEVUnknown *Expr) { 468 Instruction *Inst = dyn_cast<Instruction>(Expr->getValue()); 469 470 // Return true when Inst is defined inside the region R. 471 if (Inst && R->contains(Inst)) 472 return true; 473 474 return false; 475 } 476 477 private: 478 const Region *R; 479 }; 480 481 namespace polly { 482 /// Find all loops referenced in SCEVAddRecExprs. 483 class SCEVFindLoops { 484 SetVector<const Loop *> &Loops; 485 486 public: 487 SCEVFindLoops(SetVector<const Loop *> &Loops) : Loops(Loops) {} 488 489 bool follow(const SCEV *S) { 490 if (const SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(S)) 491 Loops.insert(AddRec->getLoop()); 492 return true; 493 } 494 bool isDone() { return false; } 495 }; 496 497 void findLoops(const SCEV *Expr, SetVector<const Loop *> &Loops) { 498 SCEVFindLoops FindLoops(Loops); 499 SCEVTraversal<SCEVFindLoops> ST(FindLoops); 500 ST.visitAll(Expr); 501 } 502 503 /// Find all values referenced in SCEVUnknowns. 504 class SCEVFindValues { 505 SetVector<Value *> &Values; 506 507 public: 508 SCEVFindValues(SetVector<Value *> &Values) : Values(Values) {} 509 510 bool follow(const SCEV *S) { 511 if (const SCEVUnknown *Unknown = dyn_cast<SCEVUnknown>(S)) 512 Values.insert(Unknown->getValue()); 513 return true; 514 } 515 bool isDone() { return false; } 516 }; 517 518 void findValues(const SCEV *Expr, SetVector<Value *> &Values) { 519 SCEVFindValues FindValues(Values); 520 SCEVTraversal<SCEVFindValues> ST(FindValues); 521 ST.visitAll(Expr); 522 } 523 524 bool hasScalarDepsInsideRegion(const SCEV *Expr, const Region *R) { 525 return SCEVInRegionDependences::hasDependences(Expr, R); 526 } 527 528 bool isAffineExpr(const Region *R, const SCEV *Expr, ScalarEvolution &SE, 529 const Value *BaseAddress) { 530 if (isa<SCEVCouldNotCompute>(Expr)) 531 return false; 532 533 SCEVValidator Validator(R, SE, BaseAddress); 534 DEBUG({ 535 dbgs() << "\n"; 536 dbgs() << "Expr: " << *Expr << "\n"; 537 dbgs() << "Region: " << R->getNameStr() << "\n"; 538 dbgs() << " -> "; 539 }); 540 541 ValidatorResult Result = Validator.visit(Expr); 542 543 DEBUG({ 544 if (Result.isValid()) 545 dbgs() << "VALID\n"; 546 dbgs() << "\n"; 547 }); 548 549 return Result.isValid(); 550 } 551 552 std::vector<const SCEV *> getParamsInAffineExpr(const Region *R, 553 const SCEV *Expr, 554 ScalarEvolution &SE, 555 const Value *BaseAddress) { 556 if (isa<SCEVCouldNotCompute>(Expr)) 557 return std::vector<const SCEV *>(); 558 559 SCEVValidator Validator(R, SE, BaseAddress); 560 ValidatorResult Result = Validator.visit(Expr); 561 assert(Result.isValid() && "Requested parameters for an invalid SCEV!"); 562 563 return Result.getParameters(); 564 } 565 566 std::pair<const SCEV *, const SCEV *> 567 extractConstantFactor(const SCEV *S, ScalarEvolution &SE) { 568 569 const SCEV *LeftOver = SE.getConstant(S->getType(), 1); 570 const SCEV *ConstPart = SE.getConstant(S->getType(), 1); 571 572 const SCEVMulExpr *M = dyn_cast<SCEVMulExpr>(S); 573 if (!M) 574 return std::make_pair(ConstPart, S); 575 576 for (const SCEV *Op : M->operands()) 577 if (isa<SCEVConstant>(Op)) 578 ConstPart = SE.getMulExpr(ConstPart, Op); 579 else 580 LeftOver = SE.getMulExpr(LeftOver, Op); 581 582 return std::make_pair(ConstPart, LeftOver); 583 } 584 } 585