1 //===- LegalityTest.cpp ---------------------------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "llvm/Transforms/Vectorize/SandboxVectorizer/Legality.h" 10 #include "llvm/Analysis/AssumptionCache.h" 11 #include "llvm/Analysis/BasicAliasAnalysis.h" 12 #include "llvm/Analysis/LoopInfo.h" 13 #include "llvm/Analysis/ScalarEvolution.h" 14 #include "llvm/Analysis/TargetLibraryInfo.h" 15 #include "llvm/AsmParser/Parser.h" 16 #include "llvm/IR/DataLayout.h" 17 #include "llvm/IR/Dominators.h" 18 #include "llvm/SandboxIR/Function.h" 19 #include "llvm/SandboxIR/Instruction.h" 20 #include "llvm/Support/SourceMgr.h" 21 #include "llvm/Transforms/Vectorize/SandboxVectorizer/InstrMaps.h" 22 #include "gmock/gmock.h" 23 #include "gtest/gtest.h" 24 25 using namespace llvm; 26 27 struct LegalityTest : public testing::Test { 28 LLVMContext C; 29 std::unique_ptr<Module> M; 30 std::unique_ptr<DominatorTree> DT; 31 std::unique_ptr<TargetLibraryInfoImpl> TLII; 32 std::unique_ptr<TargetLibraryInfo> TLI; 33 std::unique_ptr<AssumptionCache> AC; 34 std::unique_ptr<LoopInfo> LI; 35 std::unique_ptr<ScalarEvolution> SE; 36 std::unique_ptr<BasicAAResult> BAA; 37 std::unique_ptr<AAResults> AA; 38 39 void getAnalyses(llvm::Function &LLVMF) { 40 DT = std::make_unique<DominatorTree>(LLVMF); 41 TLII = std::make_unique<TargetLibraryInfoImpl>(); 42 TLI = std::make_unique<TargetLibraryInfo>(*TLII); 43 AC = std::make_unique<AssumptionCache>(LLVMF); 44 LI = std::make_unique<LoopInfo>(*DT); 45 SE = std::make_unique<ScalarEvolution>(LLVMF, *TLI, *AC, *DT, *LI); 46 BAA = std::make_unique<BasicAAResult>(LLVMF.getParent()->getDataLayout(), 47 LLVMF, *TLI, *AC, DT.get()); 48 AA = std::make_unique<AAResults>(*TLI); 49 AA->addAAResult(*BAA); 50 } 51 52 void parseIR(LLVMContext &C, const char *IR) { 53 SMDiagnostic Err; 54 M = parseAssemblyString(IR, Err, C); 55 if (!M) 56 Err.print("LegalityTest", errs()); 57 } 58 }; 59 60 static sandboxir::BasicBlock *getBasicBlockByName(sandboxir::Function *F, 61 StringRef Name) { 62 for (sandboxir::BasicBlock &BB : *F) 63 if (BB.getName() == Name) 64 return &BB; 65 llvm_unreachable("Expected to find basic block!"); 66 } 67 68 TEST_F(LegalityTest, LegalitySkipSchedule) { 69 parseIR(C, R"IR( 70 define void @foo(ptr %ptr, <2 x float> %vec2, <3 x float> %vec3, i8 %arg, float %farg0, float %farg1, i64 %v0, i64 %v1, i32 %v2) { 71 entry: 72 %gep0 = getelementptr float, ptr %ptr, i32 0 73 %gep1 = getelementptr float, ptr %ptr, i32 1 74 store float %farg0, ptr %gep1 75 br label %bb 76 77 bb: 78 %gep3 = getelementptr float, ptr %ptr, i32 3 79 %ld0 = load float, ptr %gep0 80 %ld0b = load float, ptr %gep0 81 %ld1 = load float, ptr %gep1 82 %ld3 = load float, ptr %gep3 83 store float %ld0, ptr %gep0 84 store float %ld1, ptr %gep1 85 store <2 x float> %vec2, ptr %gep1 86 store <3 x float> %vec3, ptr %gep3 87 store i8 %arg, ptr %gep1 88 %fadd0 = fadd float %farg0, %farg0 89 %fadd1 = fadd fast float %farg1, %farg1 90 %trunc0 = trunc nuw nsw i64 %v0 to i8 91 %trunc1 = trunc nsw i64 %v1 to i8 92 %trunc64to8 = trunc i64 %v0 to i8 93 %trunc32to8 = trunc i32 %v2 to i8 94 %cmpSLT = icmp slt i64 %v0, %v1 95 %cmpSGT = icmp sgt i64 %v0, %v1 96 ret void 97 } 98 )IR"); 99 llvm::Function *LLVMF = &*M->getFunction("foo"); 100 getAnalyses(*LLVMF); 101 const auto &DL = M->getDataLayout(); 102 103 sandboxir::Context Ctx(C); 104 auto *F = Ctx.createFunction(LLVMF); 105 auto *EntryBB = getBasicBlockByName(F, "entry"); 106 auto It = EntryBB->begin(); 107 [[maybe_unused]] auto *Gep0 = cast<sandboxir::GetElementPtrInst>(&*It++); 108 [[maybe_unused]] auto *Gep1 = cast<sandboxir::GetElementPtrInst>(&*It++); 109 auto *St1Entry = cast<sandboxir::StoreInst>(&*It++); 110 111 auto *BB = getBasicBlockByName(F, "bb"); 112 It = BB->begin(); 113 [[maybe_unused]] auto *Gep3 = cast<sandboxir::GetElementPtrInst>(&*It++); 114 auto *Ld0 = cast<sandboxir::LoadInst>(&*It++); 115 auto *Ld0b = cast<sandboxir::LoadInst>(&*It++); 116 auto *Ld1 = cast<sandboxir::LoadInst>(&*It++); 117 auto *Ld3 = cast<sandboxir::LoadInst>(&*It++); 118 auto *St0 = cast<sandboxir::StoreInst>(&*It++); 119 auto *St1 = cast<sandboxir::StoreInst>(&*It++); 120 auto *StVec2 = cast<sandboxir::StoreInst>(&*It++); 121 auto *StVec3 = cast<sandboxir::StoreInst>(&*It++); 122 auto *StI8 = cast<sandboxir::StoreInst>(&*It++); 123 auto *FAdd0 = cast<sandboxir::BinaryOperator>(&*It++); 124 auto *FAdd1 = cast<sandboxir::BinaryOperator>(&*It++); 125 auto *Trunc0 = cast<sandboxir::TruncInst>(&*It++); 126 auto *Trunc1 = cast<sandboxir::TruncInst>(&*It++); 127 auto *Trunc64to8 = cast<sandboxir::TruncInst>(&*It++); 128 auto *Trunc32to8 = cast<sandboxir::TruncInst>(&*It++); 129 auto *CmpSLT = cast<sandboxir::CmpInst>(&*It++); 130 auto *CmpSGT = cast<sandboxir::CmpInst>(&*It++); 131 132 llvm::sandboxir::InstrMaps IMaps(Ctx); 133 sandboxir::LegalityAnalysis Legality(*AA, *SE, DL, Ctx, IMaps); 134 const auto &Result = 135 Legality.canVectorize({St0, St1}, /*SkipScheduling=*/true); 136 EXPECT_TRUE(isa<sandboxir::Widen>(Result)); 137 138 { 139 // Check NotInstructions 140 auto &Result = Legality.canVectorize({F, St0}, /*SkipScheduling=*/true); 141 EXPECT_TRUE(isa<sandboxir::Pack>(Result)); 142 EXPECT_EQ(cast<sandboxir::Pack>(Result).getReason(), 143 sandboxir::ResultReason::NotInstructions); 144 } 145 { 146 // Check DiffOpcodes 147 const auto &Result = 148 Legality.canVectorize({St0, Ld0}, /*SkipScheduling=*/true); 149 EXPECT_TRUE(isa<sandboxir::Pack>(Result)); 150 EXPECT_EQ(cast<sandboxir::Pack>(Result).getReason(), 151 sandboxir::ResultReason::DiffOpcodes); 152 } 153 { 154 // Check DiffTypes 155 EXPECT_TRUE(isa<sandboxir::Widen>( 156 Legality.canVectorize({St0, StVec2}, /*SkipScheduling=*/true))); 157 EXPECT_TRUE(isa<sandboxir::Widen>( 158 Legality.canVectorize({StVec2, StVec3}, /*SkipScheduling=*/true))); 159 160 const auto &Result = 161 Legality.canVectorize({St0, StI8}, /*SkipScheduling=*/true); 162 EXPECT_TRUE(isa<sandboxir::Pack>(Result)); 163 EXPECT_EQ(cast<sandboxir::Pack>(Result).getReason(), 164 sandboxir::ResultReason::DiffTypes); 165 } 166 { 167 // Check DiffMathFlags 168 const auto &Result = 169 Legality.canVectorize({FAdd0, FAdd1}, /*SkipScheduling=*/true); 170 EXPECT_TRUE(isa<sandboxir::Pack>(Result)); 171 EXPECT_EQ(cast<sandboxir::Pack>(Result).getReason(), 172 sandboxir::ResultReason::DiffMathFlags); 173 } 174 { 175 // Check DiffWrapFlags 176 const auto &Result = 177 Legality.canVectorize({Trunc0, Trunc1}, /*SkipScheduling=*/true); 178 EXPECT_TRUE(isa<sandboxir::Pack>(Result)); 179 EXPECT_EQ(cast<sandboxir::Pack>(Result).getReason(), 180 sandboxir::ResultReason::DiffWrapFlags); 181 } 182 { 183 // Check DiffBBs 184 const auto &Result = 185 Legality.canVectorize({St0, St1Entry}, /*SkipScheduling=*/true); 186 EXPECT_TRUE(isa<sandboxir::Pack>(Result)); 187 EXPECT_EQ(cast<sandboxir::Pack>(Result).getReason(), 188 sandboxir::ResultReason::DiffBBs); 189 } 190 { 191 // Check DiffTypes for unary operands that have a different type. 192 const auto &Result = Legality.canVectorize({Trunc64to8, Trunc32to8}, 193 /*SkipScheduling=*/true); 194 EXPECT_TRUE(isa<sandboxir::Pack>(Result)); 195 EXPECT_EQ(cast<sandboxir::Pack>(Result).getReason(), 196 sandboxir::ResultReason::DiffTypes); 197 } 198 { 199 // Check DiffOpcodes for CMPs with different predicates. 200 const auto &Result = 201 Legality.canVectorize({CmpSLT, CmpSGT}, /*SkipScheduling=*/true); 202 EXPECT_TRUE(isa<sandboxir::Pack>(Result)); 203 EXPECT_EQ(cast<sandboxir::Pack>(Result).getReason(), 204 sandboxir::ResultReason::DiffOpcodes); 205 } 206 { 207 // Check NotConsecutive Ld0,Ld0b 208 const auto &Result = 209 Legality.canVectorize({Ld0, Ld0b}, /*SkipScheduling=*/true); 210 EXPECT_TRUE(isa<sandboxir::Pack>(Result)); 211 EXPECT_EQ(cast<sandboxir::Pack>(Result).getReason(), 212 sandboxir::ResultReason::NotConsecutive); 213 } 214 { 215 // Check NotConsecutive Ld0,Ld3 216 const auto &Result = 217 Legality.canVectorize({Ld0, Ld3}, /*SkipScheduling=*/true); 218 EXPECT_TRUE(isa<sandboxir::Pack>(Result)); 219 EXPECT_EQ(cast<sandboxir::Pack>(Result).getReason(), 220 sandboxir::ResultReason::NotConsecutive); 221 } 222 { 223 // Check Widen Ld0,Ld1 224 const auto &Result = 225 Legality.canVectorize({Ld0, Ld1}, /*SkipScheduling=*/true); 226 EXPECT_TRUE(isa<sandboxir::Widen>(Result)); 227 } 228 } 229 230 TEST_F(LegalityTest, LegalitySchedule) { 231 parseIR(C, R"IR( 232 define void @foo(ptr %ptr) { 233 %gep0 = getelementptr float, ptr %ptr, i32 0 234 %gep1 = getelementptr float, ptr %ptr, i32 1 235 %ld0 = load float, ptr %gep0 236 store float %ld0, ptr %gep1 237 %ld1 = load float, ptr %gep1 238 store float %ld0, ptr %gep0 239 store float %ld1, ptr %gep1 240 ret void 241 } 242 )IR"); 243 llvm::Function *LLVMF = &*M->getFunction("foo"); 244 getAnalyses(*LLVMF); 245 const auto &DL = M->getDataLayout(); 246 247 sandboxir::Context Ctx(C); 248 auto *F = Ctx.createFunction(LLVMF); 249 auto *BB = &*F->begin(); 250 auto It = BB->begin(); 251 [[maybe_unused]] auto *Gep0 = cast<sandboxir::GetElementPtrInst>(&*It++); 252 [[maybe_unused]] auto *Gep1 = cast<sandboxir::GetElementPtrInst>(&*It++); 253 auto *Ld0 = cast<sandboxir::LoadInst>(&*It++); 254 [[maybe_unused]] auto *ConflictingSt = cast<sandboxir::StoreInst>(&*It++); 255 auto *Ld1 = cast<sandboxir::LoadInst>(&*It++); 256 auto *St0 = cast<sandboxir::StoreInst>(&*It++); 257 auto *St1 = cast<sandboxir::StoreInst>(&*It++); 258 259 llvm::sandboxir::InstrMaps IMaps(Ctx); 260 sandboxir::LegalityAnalysis Legality(*AA, *SE, DL, Ctx, IMaps); 261 { 262 // Can vectorize St0,St1. 263 const auto &Result = Legality.canVectorize({St0, St1}); 264 EXPECT_TRUE(isa<sandboxir::Widen>(Result)); 265 } 266 { 267 // Can't vectorize Ld0,Ld1 because of conflicting store. 268 auto &Result = Legality.canVectorize({Ld0, Ld1}); 269 EXPECT_TRUE(isa<sandboxir::Pack>(Result)); 270 EXPECT_EQ(cast<sandboxir::Pack>(Result).getReason(), 271 sandboxir::ResultReason::CantSchedule); 272 } 273 } 274 275 #ifndef NDEBUG 276 TEST_F(LegalityTest, LegalityResultDump) { 277 parseIR(C, R"IR( 278 define void @foo() { 279 ret void 280 } 281 )IR"); 282 llvm::Function *LLVMF = &*M->getFunction("foo"); 283 getAnalyses(*LLVMF); 284 const auto &DL = M->getDataLayout(); 285 286 auto Matches = [](const sandboxir::LegalityResult &Result, 287 const std::string &ExpectedStr) -> bool { 288 std::string Buff; 289 raw_string_ostream OS(Buff); 290 Result.print(OS); 291 return Buff == ExpectedStr; 292 }; 293 294 sandboxir::Context Ctx(C); 295 llvm::sandboxir::InstrMaps IMaps(Ctx); 296 sandboxir::LegalityAnalysis Legality(*AA, *SE, DL, Ctx, IMaps); 297 EXPECT_TRUE( 298 Matches(Legality.createLegalityResult<sandboxir::Widen>(), "Widen")); 299 EXPECT_TRUE(Matches(Legality.createLegalityResult<sandboxir::Pack>( 300 sandboxir::ResultReason::NotInstructions), 301 "Pack Reason: NotInstructions")); 302 EXPECT_TRUE(Matches(Legality.createLegalityResult<sandboxir::Pack>( 303 sandboxir::ResultReason::DiffOpcodes), 304 "Pack Reason: DiffOpcodes")); 305 EXPECT_TRUE(Matches(Legality.createLegalityResult<sandboxir::Pack>( 306 sandboxir::ResultReason::DiffTypes), 307 "Pack Reason: DiffTypes")); 308 EXPECT_TRUE(Matches(Legality.createLegalityResult<sandboxir::Pack>( 309 sandboxir::ResultReason::DiffMathFlags), 310 "Pack Reason: DiffMathFlags")); 311 EXPECT_TRUE(Matches(Legality.createLegalityResult<sandboxir::Pack>( 312 sandboxir::ResultReason::DiffWrapFlags), 313 "Pack Reason: DiffWrapFlags")); 314 } 315 #endif // NDEBUG 316 317 TEST_F(LegalityTest, CollectDescr) { 318 parseIR(C, R"IR( 319 define void @foo(ptr %ptr) { 320 %gep0 = getelementptr float, ptr %ptr, i32 0 321 %gep1 = getelementptr float, ptr %ptr, i32 1 322 %ld0 = load float, ptr %gep0 323 %ld1 = load float, ptr %gep1 324 %vld = load <4 x float>, ptr %ptr 325 ret void 326 } 327 )IR"); 328 llvm::Function *LLVMF = &*M->getFunction("foo"); 329 getAnalyses(*LLVMF); 330 sandboxir::Context Ctx(C); 331 auto *F = Ctx.createFunction(LLVMF); 332 auto *BB = &*F->begin(); 333 auto It = BB->begin(); 334 [[maybe_unused]] auto *Gep0 = cast<sandboxir::GetElementPtrInst>(&*It++); 335 [[maybe_unused]] auto *Gep1 = cast<sandboxir::GetElementPtrInst>(&*It++); 336 auto *Ld0 = cast<sandboxir::LoadInst>(&*It++); 337 [[maybe_unused]] auto *Ld1 = cast<sandboxir::LoadInst>(&*It++); 338 auto *VLd = cast<sandboxir::LoadInst>(&*It++); 339 340 sandboxir::CollectDescr::DescrVecT Descrs; 341 using EEDescr = sandboxir::CollectDescr::ExtractElementDescr; 342 343 { 344 // Check single input, no shuffle. 345 Descrs.push_back(EEDescr(VLd, 0)); 346 Descrs.push_back(EEDescr(VLd, 1)); 347 sandboxir::CollectDescr CD(std::move(Descrs)); 348 EXPECT_TRUE(CD.getSingleInput()); 349 EXPECT_EQ(CD.getSingleInput()->first, VLd); 350 EXPECT_THAT(CD.getSingleInput()->second, testing::ElementsAre(0, 1)); 351 EXPECT_TRUE(CD.hasVectorInputs()); 352 } 353 { 354 // Check single input, shuffle. 355 Descrs.push_back(EEDescr(VLd, 1)); 356 Descrs.push_back(EEDescr(VLd, 0)); 357 sandboxir::CollectDescr CD(std::move(Descrs)); 358 EXPECT_TRUE(CD.getSingleInput()); 359 EXPECT_EQ(CD.getSingleInput()->first, VLd); 360 EXPECT_THAT(CD.getSingleInput()->second, testing::ElementsAre(1, 0)); 361 EXPECT_TRUE(CD.hasVectorInputs()); 362 } 363 { 364 // Check multiple inputs. 365 Descrs.push_back(EEDescr(Ld0)); 366 Descrs.push_back(EEDescr(VLd, 0)); 367 Descrs.push_back(EEDescr(VLd, 1)); 368 sandboxir::CollectDescr CD(std::move(Descrs)); 369 EXPECT_FALSE(CD.getSingleInput()); 370 EXPECT_TRUE(CD.hasVectorInputs()); 371 } 372 { 373 // Check multiple inputs only scalars. 374 Descrs.push_back(EEDescr(Ld0)); 375 Descrs.push_back(EEDescr(Ld1)); 376 sandboxir::CollectDescr CD(std::move(Descrs)); 377 EXPECT_FALSE(CD.getSingleInput()); 378 EXPECT_FALSE(CD.hasVectorInputs()); 379 } 380 } 381 382 TEST_F(LegalityTest, ShuffleMask) { 383 { 384 // Check SmallVector constructor. 385 SmallVector<int> Indices({0, 1, 2, 3}); 386 sandboxir::ShuffleMask Mask(std::move(Indices)); 387 EXPECT_THAT(Mask, testing::ElementsAre(0, 1, 2, 3)); 388 } 389 { 390 // Check initializer_list constructor. 391 sandboxir::ShuffleMask Mask({0, 1, 2, 3}); 392 EXPECT_THAT(Mask, testing::ElementsAre(0, 1, 2, 3)); 393 } 394 { 395 // Check ArrayRef constructor. 396 sandboxir::ShuffleMask Mask(ArrayRef<int>({0, 1, 2, 3})); 397 EXPECT_THAT(Mask, testing::ElementsAre(0, 1, 2, 3)); 398 } 399 { 400 // Check operator ArrayRef<int>(). 401 sandboxir::ShuffleMask Mask({0, 1, 2, 3}); 402 ArrayRef<int> Array = Mask; 403 EXPECT_THAT(Array, testing::ElementsAre(0, 1, 2, 3)); 404 } 405 { 406 // Check getIdentity(). 407 auto IdentityMask = sandboxir::ShuffleMask::getIdentity(4); 408 EXPECT_THAT(IdentityMask, testing::ElementsAre(0, 1, 2, 3)); 409 EXPECT_TRUE(IdentityMask.isIdentity()); 410 } 411 { 412 // Check isIdentity(). 413 sandboxir::ShuffleMask Mask1({0, 1, 2, 3}); 414 EXPECT_TRUE(Mask1.isIdentity()); 415 sandboxir::ShuffleMask Mask2({1, 2, 3, 4}); 416 EXPECT_FALSE(Mask2.isIdentity()); 417 } 418 { 419 // Check operator==(). 420 sandboxir::ShuffleMask Mask1({0, 1, 2, 3}); 421 sandboxir::ShuffleMask Mask2({0, 1, 2, 3}); 422 EXPECT_TRUE(Mask1 == Mask2); 423 EXPECT_FALSE(Mask1 != Mask2); 424 } 425 { 426 // Check operator!=(). 427 sandboxir::ShuffleMask Mask1({0, 1, 2, 3}); 428 sandboxir::ShuffleMask Mask2({0, 1, 2, 4}); 429 EXPECT_TRUE(Mask1 != Mask2); 430 EXPECT_FALSE(Mask1 == Mask2); 431 } 432 { 433 // Check size(). 434 sandboxir::ShuffleMask Mask({0, 1, 2, 3}); 435 EXPECT_EQ(Mask.size(), 4u); 436 } 437 { 438 // Check operator[]. 439 sandboxir::ShuffleMask Mask({0, 1, 2, 3}); 440 for (auto [Idx, Elm] : enumerate(Mask)) { 441 EXPECT_EQ(Elm, Mask[Idx]); 442 } 443 } 444 { 445 // Check begin(), end(). 446 sandboxir::ShuffleMask Mask({0, 1, 2, 3}); 447 sandboxir::ShuffleMask::const_iterator Begin = Mask.begin(); 448 sandboxir::ShuffleMask::const_iterator End = Mask.begin(); 449 int Idx = 0; 450 for (auto It = Begin; It != End; ++It) { 451 EXPECT_EQ(*It, Mask[Idx++]); 452 } 453 } 454 #ifndef NDEBUG 455 { 456 // Check print(OS). 457 sandboxir::ShuffleMask Mask({0, 1, 2, 3}); 458 std::string Str; 459 raw_string_ostream OS(Str); 460 Mask.print(OS); 461 EXPECT_EQ(Str, "0,1,2,3"); 462 } 463 { 464 // Check operator<<(). 465 sandboxir::ShuffleMask Mask({0, 1, 2, 3}); 466 std::string Str; 467 raw_string_ostream OS(Str); 468 OS << Mask; 469 EXPECT_EQ(Str, "0,1,2,3"); 470 } 471 #endif // NDEBUG 472 } 473