1 2 #include "polly/Support/SCEVValidator.h" 3 #include "polly/ScopInfo.h" 4 #include "llvm/Analysis/ScalarEvolution.h" 5 #include "llvm/Analysis/ScalarEvolutionExpressions.h" 6 #include "llvm/Analysis/RegionInfo.h" 7 #include "llvm/Support/Debug.h" 8 9 #include <vector> 10 11 using namespace llvm; 12 13 #define DEBUG_TYPE "polly-scev-validator" 14 15 namespace SCEVType { 16 /// @brief The type of a SCEV 17 /// 18 /// To check for the validity of a SCEV we assign to each SCEV a type. The 19 /// possible types are INT, PARAM, IV and INVALID. The order of the types is 20 /// important. The subexpressions of SCEV with a type X can only have a type 21 /// that is smaller or equal than X. 22 enum TYPE { 23 // An integer value. 24 INT, 25 26 // An expression that is constant during the execution of the Scop, 27 // but that may depend on parameters unknown at compile time. 28 PARAM, 29 30 // An expression that may change during the execution of the SCoP. 31 IV, 32 33 // An invalid expression. 34 INVALID 35 }; 36 } 37 38 /// @brief The result the validator returns for a SCEV expression. 39 class ValidatorResult { 40 /// @brief The type of the expression 41 SCEVType::TYPE Type; 42 43 /// @brief The set of Parameters in the expression. 44 std::vector<const SCEV *> Parameters; 45 46 public: 47 /// @brief The copy constructor 48 ValidatorResult(const ValidatorResult &Source) { 49 Type = Source.Type; 50 Parameters = Source.Parameters; 51 } 52 53 /// @brief Construct a result with a certain type and no parameters. 54 ValidatorResult(SCEVType::TYPE Type) : Type(Type) { 55 assert(Type != SCEVType::PARAM && "Did you forget to pass the parameter"); 56 } 57 58 /// @brief Construct a result with a certain type and a single parameter. 59 ValidatorResult(SCEVType::TYPE Type, const SCEV *Expr) : Type(Type) { 60 Parameters.push_back(Expr); 61 } 62 63 /// @brief Get the type of the ValidatorResult. 64 SCEVType::TYPE getType() { return Type; } 65 66 /// @brief Is the analyzed SCEV constant during the execution of the SCoP. 67 bool isConstant() { return Type == SCEVType::INT || Type == SCEVType::PARAM; } 68 69 /// @brief Is the analyzed SCEV valid. 70 bool isValid() { return Type != SCEVType::INVALID; } 71 72 /// @brief Is the analyzed SCEV of Type IV. 73 bool isIV() { return Type == SCEVType::IV; } 74 75 /// @brief Is the analyzed SCEV of Type INT. 76 bool isINT() { return Type == SCEVType::INT; } 77 78 /// @brief Is the analyzed SCEV of Type PARAM. 79 bool isPARAM() { return Type == SCEVType::PARAM; } 80 81 /// @brief Get the parameters of this validator result. 82 std::vector<const SCEV *> getParameters() { return Parameters; } 83 84 /// @brief Add the parameters of Source to this result. 85 void addParamsFrom(const ValidatorResult &Source) { 86 Parameters.insert(Parameters.end(), Source.Parameters.begin(), 87 Source.Parameters.end()); 88 } 89 90 /// @brief Merge a result. 91 /// 92 /// This means to merge the parameters and to set the Type to the most 93 /// specific Type that matches both. 94 ValidatorResult &merge(const ValidatorResult &ToMerge) { 95 Type = std::max(Type, ToMerge.Type); 96 addParamsFrom(ToMerge); 97 return *this; 98 } 99 100 void print(raw_ostream &OS) { 101 switch (Type) { 102 case SCEVType::INT: 103 OS << "SCEVType::INT"; 104 break; 105 case SCEVType::PARAM: 106 OS << "SCEVType::PARAM"; 107 break; 108 case SCEVType::IV: 109 OS << "SCEVType::IV"; 110 break; 111 case SCEVType::INVALID: 112 OS << "SCEVType::INVALID"; 113 break; 114 } 115 } 116 }; 117 118 raw_ostream &operator<<(raw_ostream &OS, class ValidatorResult &VR) { 119 VR.print(OS); 120 return OS; 121 } 122 123 /// Check if a SCEV is valid in a SCoP. 124 struct SCEVValidator 125 : public SCEVVisitor<SCEVValidator, class ValidatorResult> { 126 private: 127 const Region *R; 128 ScalarEvolution &SE; 129 const Value *BaseAddress; 130 131 public: 132 SCEVValidator(const Region *R, ScalarEvolution &SE, const Value *BaseAddress) 133 : R(R), SE(SE), BaseAddress(BaseAddress) {} 134 135 class ValidatorResult visitConstant(const SCEVConstant *Constant) { 136 return ValidatorResult(SCEVType::INT); 137 } 138 139 class ValidatorResult visitTruncateExpr(const SCEVTruncateExpr *Expr) { 140 ValidatorResult Op = visit(Expr->getOperand()); 141 142 switch (Op.getType()) { 143 case SCEVType::INT: 144 case SCEVType::PARAM: 145 // We currently do not represent a truncate expression as an affine 146 // expression. If it is constant during Scop execution, we treat it as a 147 // parameter. 148 return ValidatorResult(SCEVType::PARAM, Expr); 149 case SCEVType::IV: 150 DEBUG(dbgs() << "INVALID: Truncation of SCEVType::IV expression"); 151 return ValidatorResult(SCEVType::INVALID); 152 case SCEVType::INVALID: 153 return Op; 154 } 155 156 llvm_unreachable("Unknown SCEVType"); 157 } 158 159 class ValidatorResult visitZeroExtendExpr(const SCEVZeroExtendExpr *Expr) { 160 161 // Pattern matching rules to capture some bit and modulo computations: 162 // 163 // EXP % 2^C <==> 164 // [A] (i + c) & (2^C - 1) ==> zext iC {c,+,1}<%for_i> to IXX 165 // [B] (p + q) & (2^C - 1) ==> zext iC (trunc iXX %p_add_q to iC) to iXX 166 // [C] (i + p) & (2^C - 1) ==> zext iC {p & (2^C - 1),+,1}<%for_i> to iXX 167 // ==> zext iC {trunc iXX %p to iC,+,1}<%for_i> to 168 169 // Check for [A] and [C]. 170 const SCEV *OpS = Expr->getOperand(); 171 if (auto *OpAR = dyn_cast<SCEVAddRecExpr>(OpS)) { 172 const SCEV *OpARStart = OpAR->getStart(); 173 174 // Special case for [C]. 175 if (auto *OpARStartTR = dyn_cast<SCEVTruncateExpr>(OpARStart)) 176 OpARStart = OpARStartTR->getOperand(); 177 178 ValidatorResult OpARStartVR = visit(OpARStart); 179 if (OpARStartVR.isConstant() && OpAR->getStepRecurrence(SE)->isOne()) 180 return OpARStartVR; 181 } 182 183 // Check for [B]. 184 if (auto *OpTR = dyn_cast<SCEVTruncateExpr>(OpS)) { 185 ValidatorResult OpTRVR = visit(OpTR->getOperand()); 186 if (OpTRVR.isConstant()) 187 return OpTRVR; 188 } 189 190 ValidatorResult Op = visit(OpS); 191 switch (Op.getType()) { 192 case SCEVType::INT: 193 case SCEVType::PARAM: 194 // We currently do not represent a truncate expression as an affine 195 // expression. If it is constant during Scop execution, we treat it as a 196 // parameter. 197 return ValidatorResult(SCEVType::PARAM, Expr); 198 case SCEVType::IV: 199 DEBUG(dbgs() << "INVALID: ZeroExtend of SCEVType::IV expression"); 200 return ValidatorResult(SCEVType::INVALID); 201 case SCEVType::INVALID: 202 return Op; 203 } 204 205 llvm_unreachable("Unknown SCEVType"); 206 } 207 208 class ValidatorResult visitSignExtendExpr(const SCEVSignExtendExpr *Expr) { 209 // We currently allow only signed SCEV expressions. In the case of a 210 // signed value, a sign extend is a noop. 211 // 212 // TODO: Reconsider this when we add support for unsigned values. 213 return visit(Expr->getOperand()); 214 } 215 216 class ValidatorResult visitAddExpr(const SCEVAddExpr *Expr) { 217 ValidatorResult Return(SCEVType::INT); 218 219 for (int i = 0, e = Expr->getNumOperands(); i < e; ++i) { 220 ValidatorResult Op = visit(Expr->getOperand(i)); 221 Return.merge(Op); 222 223 // Early exit. 224 if (!Return.isValid()) 225 break; 226 } 227 228 // TODO: Check for NSW and NUW. 229 return Return; 230 } 231 232 class ValidatorResult visitMulExpr(const SCEVMulExpr *Expr) { 233 ValidatorResult Return(SCEVType::INT); 234 235 bool HasMultipleParams = false; 236 237 for (int i = 0, e = Expr->getNumOperands(); i < e; ++i) { 238 ValidatorResult Op = visit(Expr->getOperand(i)); 239 240 if (Op.isINT()) 241 continue; 242 243 if (Op.isPARAM() && Return.isPARAM()) { 244 HasMultipleParams = true; 245 continue; 246 } 247 248 if ((Op.isIV() || Op.isPARAM()) && !Return.isINT()) { 249 DEBUG(dbgs() << "INVALID: More than one non-int operand in MulExpr\n" 250 << "\tExpr: " << *Expr << "\n" 251 << "\tPrevious expression type: " << Return << "\n" 252 << "\tNext operand (" << Op 253 << "): " << *Expr->getOperand(i) << "\n"); 254 255 return ValidatorResult(SCEVType::INVALID); 256 } 257 258 Return.merge(Op); 259 } 260 261 if (HasMultipleParams) 262 return ValidatorResult(SCEVType::PARAM, Expr); 263 264 // TODO: Check for NSW and NUW. 265 return Return; 266 } 267 268 class ValidatorResult visitUDivExpr(const SCEVUDivExpr *Expr) { 269 ValidatorResult LHS = visit(Expr->getLHS()); 270 ValidatorResult RHS = visit(Expr->getRHS()); 271 272 // We currently do not represent an unsigned division as an affine 273 // expression. If the division is constant during Scop execution we treat it 274 // as a parameter, otherwise we bail out. 275 if (LHS.isConstant() && RHS.isConstant()) 276 return ValidatorResult(SCEVType::PARAM, Expr); 277 278 DEBUG(dbgs() << "INVALID: unsigned division of non-constant expressions"); 279 return ValidatorResult(SCEVType::INVALID); 280 } 281 282 class ValidatorResult visitAddRecExpr(const SCEVAddRecExpr *Expr) { 283 if (!Expr->isAffine()) { 284 DEBUG(dbgs() << "INVALID: AddRec is not affine"); 285 return ValidatorResult(SCEVType::INVALID); 286 } 287 288 ValidatorResult Start = visit(Expr->getStart()); 289 ValidatorResult Recurrence = visit(Expr->getStepRecurrence(SE)); 290 291 if (!Start.isValid()) 292 return Start; 293 294 if (!Recurrence.isValid()) 295 return Recurrence; 296 297 if (R->contains(Expr->getLoop())) { 298 if (Recurrence.isINT()) { 299 ValidatorResult Result(SCEVType::IV); 300 Result.addParamsFrom(Start); 301 return Result; 302 } 303 304 DEBUG(dbgs() << "INVALID: AddRec within scop has non-int" 305 "recurrence part"); 306 return ValidatorResult(SCEVType::INVALID); 307 } 308 309 assert(Start.isConstant() && Recurrence.isConstant() && 310 "Expected 'Start' and 'Recurrence' to be constant"); 311 312 // Directly generate ValidatorResult for Expr if 'start' is zero. 313 if (Expr->getStart()->isZero()) 314 return ValidatorResult(SCEVType::PARAM, Expr); 315 316 // Translate AddRecExpr from '{start, +, inc}' into 'start + {0, +, inc}' 317 // if 'start' is not zero. 318 const SCEV *ZeroStartExpr = SE.getAddRecExpr( 319 SE.getConstant(Expr->getStart()->getType(), 0), 320 Expr->getStepRecurrence(SE), Expr->getLoop(), SCEV::FlagAnyWrap); 321 322 ValidatorResult ZeroStartResult = 323 ValidatorResult(SCEVType::PARAM, ZeroStartExpr); 324 ZeroStartResult.addParamsFrom(Start); 325 326 return ZeroStartResult; 327 } 328 329 class ValidatorResult visitSMaxExpr(const SCEVSMaxExpr *Expr) { 330 ValidatorResult Return(SCEVType::INT); 331 332 for (int i = 0, e = Expr->getNumOperands(); i < e; ++i) { 333 ValidatorResult Op = visit(Expr->getOperand(i)); 334 335 if (!Op.isValid()) 336 return Op; 337 338 Return.merge(Op); 339 } 340 341 return Return; 342 } 343 344 class ValidatorResult visitUMaxExpr(const SCEVUMaxExpr *Expr) { 345 // We do not support unsigned operations. If 'Expr' is constant during Scop 346 // execution we treat this as a parameter, otherwise we bail out. 347 for (int i = 0, e = Expr->getNumOperands(); i < e; ++i) { 348 ValidatorResult Op = visit(Expr->getOperand(i)); 349 350 if (!Op.isConstant()) { 351 DEBUG(dbgs() << "INVALID: UMaxExpr has a non-constant operand"); 352 return ValidatorResult(SCEVType::INVALID); 353 } 354 } 355 356 return ValidatorResult(SCEVType::PARAM, Expr); 357 } 358 359 ValidatorResult visitUnknown(const SCEVUnknown *Expr) { 360 Value *V = Expr->getValue(); 361 362 // We currently only support integer types. It may be useful to support 363 // pointer types, e.g. to support code like: 364 // 365 // if (A) 366 // A[i] = 1; 367 // 368 // See test/CodeGen/20120316-InvalidCast.ll 369 if (!Expr->getType()->isIntegerTy()) { 370 DEBUG(dbgs() << "INVALID: UnknownExpr is not an integer type"); 371 return ValidatorResult(SCEVType::INVALID); 372 } 373 374 if (isa<UndefValue>(V)) { 375 DEBUG(dbgs() << "INVALID: UnknownExpr references an undef value"); 376 return ValidatorResult(SCEVType::INVALID); 377 } 378 379 if (auto *I = dyn_cast<Instruction>(Expr->getValue())) { 380 if (I->getOpcode() == Instruction::SRem) { 381 382 ValidatorResult Op0 = visit(SE.getSCEV(I->getOperand(0))); 383 if (!Op0.isValid()) 384 return ValidatorResult(SCEVType::INVALID); 385 386 ValidatorResult Op1 = visit(SE.getSCEV(I->getOperand(1))); 387 if (!Op1.isValid() || !Op1.isINT()) 388 return ValidatorResult(SCEVType::INVALID); 389 390 Op0.merge(Op1); 391 return Op0; 392 } 393 394 if (R->contains(I)) { 395 DEBUG(dbgs() << "INVALID: UnknownExpr references an instruction " 396 "within the region\n"); 397 return ValidatorResult(SCEVType::INVALID); 398 } 399 } 400 401 if (BaseAddress == V) { 402 DEBUG(dbgs() << "INVALID: UnknownExpr references BaseAddress\n"); 403 return ValidatorResult(SCEVType::INVALID); 404 } 405 406 return ValidatorResult(SCEVType::PARAM, Expr); 407 } 408 }; 409 410 /// @brief Check whether a SCEV refers to an SSA name defined inside a region. 411 /// 412 struct SCEVInRegionDependences 413 : public SCEVVisitor<SCEVInRegionDependences, bool> { 414 public: 415 /// Returns true when the SCEV has SSA names defined in region R. 416 static bool hasDependences(const SCEV *S, const Region *R) { 417 SCEVInRegionDependences Ignore(R); 418 return Ignore.visit(S); 419 } 420 421 SCEVInRegionDependences(const Region *R) : R(R) {} 422 423 bool visit(const SCEV *Expr) { 424 return SCEVVisitor<SCEVInRegionDependences, bool>::visit(Expr); 425 } 426 427 bool visitConstant(const SCEVConstant *Constant) { return false; } 428 429 bool visitTruncateExpr(const SCEVTruncateExpr *Expr) { 430 return visit(Expr->getOperand()); 431 } 432 433 bool visitZeroExtendExpr(const SCEVZeroExtendExpr *Expr) { 434 return visit(Expr->getOperand()); 435 } 436 437 bool visitSignExtendExpr(const SCEVSignExtendExpr *Expr) { 438 return visit(Expr->getOperand()); 439 } 440 441 bool visitAddExpr(const SCEVAddExpr *Expr) { 442 for (int i = 0, e = Expr->getNumOperands(); i < e; ++i) 443 if (visit(Expr->getOperand(i))) 444 return true; 445 446 return false; 447 } 448 449 bool visitMulExpr(const SCEVMulExpr *Expr) { 450 for (int i = 0, e = Expr->getNumOperands(); i < e; ++i) 451 if (visit(Expr->getOperand(i))) 452 return true; 453 454 return false; 455 } 456 457 bool visitUDivExpr(const SCEVUDivExpr *Expr) { 458 if (visit(Expr->getLHS())) 459 return true; 460 461 if (visit(Expr->getRHS())) 462 return true; 463 464 return false; 465 } 466 467 bool visitAddRecExpr(const SCEVAddRecExpr *Expr) { 468 if (visit(Expr->getStart())) 469 return true; 470 471 for (size_t i = 0; i < Expr->getNumOperands(); ++i) 472 if (visit(Expr->getOperand(i))) 473 return true; 474 475 return false; 476 } 477 478 bool visitSMaxExpr(const SCEVSMaxExpr *Expr) { 479 for (size_t i = 0; i < Expr->getNumOperands(); ++i) 480 if (visit(Expr->getOperand(i))) 481 return true; 482 483 return false; 484 } 485 486 bool visitUMaxExpr(const SCEVUMaxExpr *Expr) { 487 for (size_t i = 0; i < Expr->getNumOperands(); ++i) 488 if (visit(Expr->getOperand(i))) 489 return true; 490 491 return false; 492 } 493 494 bool visitUnknown(const SCEVUnknown *Expr) { 495 Instruction *Inst = dyn_cast<Instruction>(Expr->getValue()); 496 497 // Return true when Inst is defined inside the region R. 498 if (Inst && R->contains(Inst)) 499 return true; 500 501 return false; 502 } 503 504 private: 505 const Region *R; 506 }; 507 508 namespace polly { 509 bool hasScalarDepsInsideRegion(const SCEV *Expr, const Region *R) { 510 return SCEVInRegionDependences::hasDependences(Expr, R); 511 } 512 513 bool isAffineExpr(const Region *R, const SCEV *Expr, ScalarEvolution &SE, 514 const Value *BaseAddress) { 515 if (isa<SCEVCouldNotCompute>(Expr)) 516 return false; 517 518 SCEVValidator Validator(R, SE, BaseAddress); 519 DEBUG(dbgs() << "\n"; dbgs() << "Expr: " << *Expr << "\n"; 520 dbgs() << "Region: " << R->getNameStr() << "\n"; dbgs() << " -> "); 521 522 ValidatorResult Result = Validator.visit(Expr); 523 524 DEBUG(if (Result.isValid()) dbgs() << "VALID\n"; dbgs() << "\n";); 525 526 return Result.isValid(); 527 } 528 529 std::vector<const SCEV *> getParamsInAffineExpr(const Region *R, 530 const SCEV *Expr, 531 ScalarEvolution &SE, 532 const Value *BaseAddress) { 533 if (isa<SCEVCouldNotCompute>(Expr)) 534 return std::vector<const SCEV *>(); 535 536 SCEVValidator Validator(R, SE, BaseAddress); 537 ValidatorResult Result = Validator.visit(Expr); 538 539 return Result.getParameters(); 540 } 541 } 542