1 //===- LegalityTest.cpp ---------------------------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "llvm/Transforms/Vectorize/SandboxVectorizer/Legality.h" 10 #include "llvm/Analysis/AssumptionCache.h" 11 #include "llvm/Analysis/BasicAliasAnalysis.h" 12 #include "llvm/Analysis/LoopInfo.h" 13 #include "llvm/Analysis/ScalarEvolution.h" 14 #include "llvm/Analysis/TargetLibraryInfo.h" 15 #include "llvm/AsmParser/Parser.h" 16 #include "llvm/IR/DataLayout.h" 17 #include "llvm/IR/Dominators.h" 18 #include "llvm/SandboxIR/Function.h" 19 #include "llvm/SandboxIR/Instruction.h" 20 #include "llvm/Support/SourceMgr.h" 21 #include "gtest/gtest.h" 22 23 using namespace llvm; 24 25 struct LegalityTest : public testing::Test { 26 LLVMContext C; 27 std::unique_ptr<Module> M; 28 std::unique_ptr<DominatorTree> DT; 29 std::unique_ptr<TargetLibraryInfoImpl> TLII; 30 std::unique_ptr<TargetLibraryInfo> TLI; 31 std::unique_ptr<AssumptionCache> AC; 32 std::unique_ptr<LoopInfo> LI; 33 std::unique_ptr<ScalarEvolution> SE; 34 std::unique_ptr<BasicAAResult> BAA; 35 std::unique_ptr<AAResults> AA; 36 37 void getAnalyses(llvm::Function &LLVMF) { 38 DT = std::make_unique<DominatorTree>(LLVMF); 39 TLII = std::make_unique<TargetLibraryInfoImpl>(); 40 TLI = std::make_unique<TargetLibraryInfo>(*TLII); 41 AC = std::make_unique<AssumptionCache>(LLVMF); 42 LI = std::make_unique<LoopInfo>(*DT); 43 SE = std::make_unique<ScalarEvolution>(LLVMF, *TLI, *AC, *DT, *LI); 44 BAA = std::make_unique<BasicAAResult>(LLVMF.getParent()->getDataLayout(), 45 LLVMF, *TLI, *AC, DT.get()); 46 AA = std::make_unique<AAResults>(*TLI); 47 AA->addAAResult(*BAA); 48 } 49 50 void parseIR(LLVMContext &C, const char *IR) { 51 SMDiagnostic Err; 52 M = parseAssemblyString(IR, Err, C); 53 if (!M) 54 Err.print("LegalityTest", errs()); 55 } 56 }; 57 58 TEST_F(LegalityTest, LegalitySkipSchedule) { 59 parseIR(C, R"IR( 60 define void @foo(ptr %ptr, <2 x float> %vec2, <3 x float> %vec3, i8 %arg, float %farg0, float %farg1, i64 %v0, i64 %v1, i32 %v2) { 61 %gep0 = getelementptr float, ptr %ptr, i32 0 62 %gep1 = getelementptr float, ptr %ptr, i32 1 63 %gep3 = getelementptr float, ptr %ptr, i32 3 64 %ld0 = load float, ptr %gep0 65 %ld0b = load float, ptr %gep0 66 %ld1 = load float, ptr %gep1 67 %ld3 = load float, ptr %gep3 68 store float %ld0, ptr %gep0 69 store float %ld1, ptr %gep1 70 store <2 x float> %vec2, ptr %gep1 71 store <3 x float> %vec3, ptr %gep3 72 store i8 %arg, ptr %gep1 73 %fadd0 = fadd float %farg0, %farg0 74 %fadd1 = fadd fast float %farg1, %farg1 75 %trunc0 = trunc nuw nsw i64 %v0 to i8 76 %trunc1 = trunc nsw i64 %v1 to i8 77 %trunc64to8 = trunc i64 %v0 to i8 78 %trunc32to8 = trunc i32 %v2 to i8 79 %cmpSLT = icmp slt i64 %v0, %v1 80 %cmpSGT = icmp sgt i64 %v0, %v1 81 ret void 82 } 83 )IR"); 84 llvm::Function *LLVMF = &*M->getFunction("foo"); 85 getAnalyses(*LLVMF); 86 const auto &DL = M->getDataLayout(); 87 88 sandboxir::Context Ctx(C); 89 auto *F = Ctx.createFunction(LLVMF); 90 auto *BB = &*F->begin(); 91 auto It = BB->begin(); 92 [[maybe_unused]] auto *Gep0 = cast<sandboxir::GetElementPtrInst>(&*It++); 93 [[maybe_unused]] auto *Gep1 = cast<sandboxir::GetElementPtrInst>(&*It++); 94 [[maybe_unused]] auto *Gep3 = cast<sandboxir::GetElementPtrInst>(&*It++); 95 auto *Ld0 = cast<sandboxir::LoadInst>(&*It++); 96 auto *Ld0b = cast<sandboxir::LoadInst>(&*It++); 97 auto *Ld1 = cast<sandboxir::LoadInst>(&*It++); 98 auto *Ld3 = cast<sandboxir::LoadInst>(&*It++); 99 auto *St0 = cast<sandboxir::StoreInst>(&*It++); 100 auto *St1 = cast<sandboxir::StoreInst>(&*It++); 101 auto *StVec2 = cast<sandboxir::StoreInst>(&*It++); 102 auto *StVec3 = cast<sandboxir::StoreInst>(&*It++); 103 auto *StI8 = cast<sandboxir::StoreInst>(&*It++); 104 auto *FAdd0 = cast<sandboxir::BinaryOperator>(&*It++); 105 auto *FAdd1 = cast<sandboxir::BinaryOperator>(&*It++); 106 auto *Trunc0 = cast<sandboxir::TruncInst>(&*It++); 107 auto *Trunc1 = cast<sandboxir::TruncInst>(&*It++); 108 auto *Trunc64to8 = cast<sandboxir::TruncInst>(&*It++); 109 auto *Trunc32to8 = cast<sandboxir::TruncInst>(&*It++); 110 auto *CmpSLT = cast<sandboxir::CmpInst>(&*It++); 111 auto *CmpSGT = cast<sandboxir::CmpInst>(&*It++); 112 113 sandboxir::LegalityAnalysis Legality(*AA, *SE, DL, Ctx); 114 const auto &Result = 115 Legality.canVectorize({St0, St1}, /*SkipScheduling=*/true); 116 EXPECT_TRUE(isa<sandboxir::Widen>(Result)); 117 118 { 119 // Check NotInstructions 120 auto &Result = Legality.canVectorize({F, St0}, /*SkipScheduling=*/true); 121 EXPECT_TRUE(isa<sandboxir::Pack>(Result)); 122 EXPECT_EQ(cast<sandboxir::Pack>(Result).getReason(), 123 sandboxir::ResultReason::NotInstructions); 124 } 125 { 126 // Check DiffOpcodes 127 const auto &Result = 128 Legality.canVectorize({St0, Ld0}, /*SkipScheduling=*/true); 129 EXPECT_TRUE(isa<sandboxir::Pack>(Result)); 130 EXPECT_EQ(cast<sandboxir::Pack>(Result).getReason(), 131 sandboxir::ResultReason::DiffOpcodes); 132 } 133 { 134 // Check DiffTypes 135 EXPECT_TRUE(isa<sandboxir::Widen>( 136 Legality.canVectorize({St0, StVec2}, /*SkipScheduling=*/true))); 137 EXPECT_TRUE(isa<sandboxir::Widen>( 138 Legality.canVectorize({StVec2, StVec3}, /*SkipScheduling=*/true))); 139 140 const auto &Result = 141 Legality.canVectorize({St0, StI8}, /*SkipScheduling=*/true); 142 EXPECT_TRUE(isa<sandboxir::Pack>(Result)); 143 EXPECT_EQ(cast<sandboxir::Pack>(Result).getReason(), 144 sandboxir::ResultReason::DiffTypes); 145 } 146 { 147 // Check DiffMathFlags 148 const auto &Result = 149 Legality.canVectorize({FAdd0, FAdd1}, /*SkipScheduling=*/true); 150 EXPECT_TRUE(isa<sandboxir::Pack>(Result)); 151 EXPECT_EQ(cast<sandboxir::Pack>(Result).getReason(), 152 sandboxir::ResultReason::DiffMathFlags); 153 } 154 { 155 // Check DiffWrapFlags 156 const auto &Result = 157 Legality.canVectorize({Trunc0, Trunc1}, /*SkipScheduling=*/true); 158 EXPECT_TRUE(isa<sandboxir::Pack>(Result)); 159 EXPECT_EQ(cast<sandboxir::Pack>(Result).getReason(), 160 sandboxir::ResultReason::DiffWrapFlags); 161 } 162 { 163 // Check DiffTypes for unary operands that have a different type. 164 const auto &Result = Legality.canVectorize({Trunc64to8, Trunc32to8}, 165 /*SkipScheduling=*/true); 166 EXPECT_TRUE(isa<sandboxir::Pack>(Result)); 167 EXPECT_EQ(cast<sandboxir::Pack>(Result).getReason(), 168 sandboxir::ResultReason::DiffTypes); 169 } 170 { 171 // Check DiffOpcodes for CMPs with different predicates. 172 const auto &Result = 173 Legality.canVectorize({CmpSLT, CmpSGT}, /*SkipScheduling=*/true); 174 EXPECT_TRUE(isa<sandboxir::Pack>(Result)); 175 EXPECT_EQ(cast<sandboxir::Pack>(Result).getReason(), 176 sandboxir::ResultReason::DiffOpcodes); 177 } 178 { 179 // Check NotConsecutive Ld0,Ld0b 180 const auto &Result = 181 Legality.canVectorize({Ld0, Ld0b}, /*SkipScheduling=*/true); 182 EXPECT_TRUE(isa<sandboxir::Pack>(Result)); 183 EXPECT_EQ(cast<sandboxir::Pack>(Result).getReason(), 184 sandboxir::ResultReason::NotConsecutive); 185 } 186 { 187 // Check NotConsecutive Ld0,Ld3 188 const auto &Result = 189 Legality.canVectorize({Ld0, Ld3}, /*SkipScheduling=*/true); 190 EXPECT_TRUE(isa<sandboxir::Pack>(Result)); 191 EXPECT_EQ(cast<sandboxir::Pack>(Result).getReason(), 192 sandboxir::ResultReason::NotConsecutive); 193 } 194 { 195 // Check Widen Ld0,Ld1 196 const auto &Result = 197 Legality.canVectorize({Ld0, Ld1}, /*SkipScheduling=*/true); 198 EXPECT_TRUE(isa<sandboxir::Widen>(Result)); 199 } 200 } 201 202 TEST_F(LegalityTest, LegalitySchedule) { 203 parseIR(C, R"IR( 204 define void @foo(ptr %ptr) { 205 %gep0 = getelementptr float, ptr %ptr, i32 0 206 %gep1 = getelementptr float, ptr %ptr, i32 1 207 %ld0 = load float, ptr %gep0 208 store float %ld0, ptr %gep1 209 %ld1 = load float, ptr %gep1 210 store float %ld0, ptr %gep0 211 store float %ld1, ptr %gep1 212 ret void 213 } 214 )IR"); 215 llvm::Function *LLVMF = &*M->getFunction("foo"); 216 getAnalyses(*LLVMF); 217 const auto &DL = M->getDataLayout(); 218 219 sandboxir::Context Ctx(C); 220 auto *F = Ctx.createFunction(LLVMF); 221 auto *BB = &*F->begin(); 222 auto It = BB->begin(); 223 [[maybe_unused]] auto *Gep0 = cast<sandboxir::GetElementPtrInst>(&*It++); 224 [[maybe_unused]] auto *Gep1 = cast<sandboxir::GetElementPtrInst>(&*It++); 225 auto *Ld0 = cast<sandboxir::LoadInst>(&*It++); 226 [[maybe_unused]] auto *ConflictingSt = cast<sandboxir::StoreInst>(&*It++); 227 auto *Ld1 = cast<sandboxir::LoadInst>(&*It++); 228 auto *St0 = cast<sandboxir::StoreInst>(&*It++); 229 auto *St1 = cast<sandboxir::StoreInst>(&*It++); 230 231 sandboxir::LegalityAnalysis Legality(*AA, *SE, DL, Ctx); 232 { 233 // Can vectorize St0,St1. 234 const auto &Result = Legality.canVectorize({St0, St1}); 235 EXPECT_TRUE(isa<sandboxir::Widen>(Result)); 236 } 237 { 238 // Can't vectorize Ld0,Ld1 because of conflicting store. 239 auto &Result = Legality.canVectorize({Ld0, Ld1}); 240 EXPECT_TRUE(isa<sandboxir::Pack>(Result)); 241 EXPECT_EQ(cast<sandboxir::Pack>(Result).getReason(), 242 sandboxir::ResultReason::CantSchedule); 243 } 244 } 245 246 #ifndef NDEBUG 247 TEST_F(LegalityTest, LegalityResultDump) { 248 parseIR(C, R"IR( 249 define void @foo() { 250 ret void 251 } 252 )IR"); 253 llvm::Function *LLVMF = &*M->getFunction("foo"); 254 getAnalyses(*LLVMF); 255 const auto &DL = M->getDataLayout(); 256 257 auto Matches = [](const sandboxir::LegalityResult &Result, 258 const std::string &ExpectedStr) -> bool { 259 std::string Buff; 260 raw_string_ostream OS(Buff); 261 Result.print(OS); 262 return Buff == ExpectedStr; 263 }; 264 265 sandboxir::Context Ctx(C); 266 sandboxir::LegalityAnalysis Legality(*AA, *SE, DL, Ctx); 267 EXPECT_TRUE( 268 Matches(Legality.createLegalityResult<sandboxir::Widen>(), "Widen")); 269 EXPECT_TRUE(Matches(Legality.createLegalityResult<sandboxir::Pack>( 270 sandboxir::ResultReason::NotInstructions), 271 "Pack Reason: NotInstructions")); 272 EXPECT_TRUE(Matches(Legality.createLegalityResult<sandboxir::Pack>( 273 sandboxir::ResultReason::DiffOpcodes), 274 "Pack Reason: DiffOpcodes")); 275 EXPECT_TRUE(Matches(Legality.createLegalityResult<sandboxir::Pack>( 276 sandboxir::ResultReason::DiffTypes), 277 "Pack Reason: DiffTypes")); 278 EXPECT_TRUE(Matches(Legality.createLegalityResult<sandboxir::Pack>( 279 sandboxir::ResultReason::DiffMathFlags), 280 "Pack Reason: DiffMathFlags")); 281 EXPECT_TRUE(Matches(Legality.createLegalityResult<sandboxir::Pack>( 282 sandboxir::ResultReason::DiffWrapFlags), 283 "Pack Reason: DiffWrapFlags")); 284 } 285 #endif // NDEBUG 286