xref: /llvm-project/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/LegalityTest.cpp (revision 5942a99f8b7dd361c35eb1c9c32b2475dce2c0b2)
1 //===- LegalityTest.cpp ---------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "llvm/Transforms/Vectorize/SandboxVectorizer/Legality.h"
10 #include "llvm/Analysis/AssumptionCache.h"
11 #include "llvm/Analysis/BasicAliasAnalysis.h"
12 #include "llvm/Analysis/LoopInfo.h"
13 #include "llvm/Analysis/ScalarEvolution.h"
14 #include "llvm/Analysis/TargetLibraryInfo.h"
15 #include "llvm/AsmParser/Parser.h"
16 #include "llvm/IR/DataLayout.h"
17 #include "llvm/IR/Dominators.h"
18 #include "llvm/SandboxIR/Function.h"
19 #include "llvm/SandboxIR/Instruction.h"
20 #include "llvm/Support/SourceMgr.h"
21 #include "gtest/gtest.h"
22 
23 using namespace llvm;
24 
25 struct LegalityTest : public testing::Test {
26   LLVMContext C;
27   std::unique_ptr<Module> M;
28   std::unique_ptr<DominatorTree> DT;
29   std::unique_ptr<TargetLibraryInfoImpl> TLII;
30   std::unique_ptr<TargetLibraryInfo> TLI;
31   std::unique_ptr<AssumptionCache> AC;
32   std::unique_ptr<LoopInfo> LI;
33   std::unique_ptr<ScalarEvolution> SE;
34   std::unique_ptr<BasicAAResult> BAA;
35   std::unique_ptr<AAResults> AA;
36 
37   void getAnalyses(llvm::Function &LLVMF) {
38     DT = std::make_unique<DominatorTree>(LLVMF);
39     TLII = std::make_unique<TargetLibraryInfoImpl>();
40     TLI = std::make_unique<TargetLibraryInfo>(*TLII);
41     AC = std::make_unique<AssumptionCache>(LLVMF);
42     LI = std::make_unique<LoopInfo>(*DT);
43     SE = std::make_unique<ScalarEvolution>(LLVMF, *TLI, *AC, *DT, *LI);
44     BAA = std::make_unique<BasicAAResult>(LLVMF.getParent()->getDataLayout(),
45                                           LLVMF, *TLI, *AC, DT.get());
46     AA = std::make_unique<AAResults>(*TLI);
47     AA->addAAResult(*BAA);
48   }
49 
50   void parseIR(LLVMContext &C, const char *IR) {
51     SMDiagnostic Err;
52     M = parseAssemblyString(IR, Err, C);
53     if (!M)
54       Err.print("LegalityTest", errs());
55   }
56 };
57 
58 TEST_F(LegalityTest, LegalitySkipSchedule) {
59   parseIR(C, R"IR(
60 define void @foo(ptr %ptr, <2 x float> %vec2, <3 x float> %vec3, i8 %arg, float %farg0, float %farg1, i64 %v0, i64 %v1, i32 %v2) {
61   %gep0 = getelementptr float, ptr %ptr, i32 0
62   %gep1 = getelementptr float, ptr %ptr, i32 1
63   %gep3 = getelementptr float, ptr %ptr, i32 3
64   %ld0 = load float, ptr %gep0
65   %ld0b = load float, ptr %gep0
66   %ld1 = load float, ptr %gep1
67   %ld3 = load float, ptr %gep3
68   store float %ld0, ptr %gep0
69   store float %ld1, ptr %gep1
70   store <2 x float> %vec2, ptr %gep1
71   store <3 x float> %vec3, ptr %gep3
72   store i8 %arg, ptr %gep1
73   %fadd0 = fadd float %farg0, %farg0
74   %fadd1 = fadd fast float %farg1, %farg1
75   %trunc0 = trunc nuw nsw i64 %v0 to i8
76   %trunc1 = trunc nsw i64 %v1 to i8
77   %trunc64to8 = trunc i64 %v0 to i8
78   %trunc32to8 = trunc i32 %v2 to i8
79   %cmpSLT = icmp slt i64 %v0, %v1
80   %cmpSGT = icmp sgt i64 %v0, %v1
81   ret void
82 }
83 )IR");
84   llvm::Function *LLVMF = &*M->getFunction("foo");
85   getAnalyses(*LLVMF);
86   const auto &DL = M->getDataLayout();
87 
88   sandboxir::Context Ctx(C);
89   auto *F = Ctx.createFunction(LLVMF);
90   auto *BB = &*F->begin();
91   auto It = BB->begin();
92   [[maybe_unused]] auto *Gep0 = cast<sandboxir::GetElementPtrInst>(&*It++);
93   [[maybe_unused]] auto *Gep1 = cast<sandboxir::GetElementPtrInst>(&*It++);
94   [[maybe_unused]] auto *Gep3 = cast<sandboxir::GetElementPtrInst>(&*It++);
95   auto *Ld0 = cast<sandboxir::LoadInst>(&*It++);
96   auto *Ld0b = cast<sandboxir::LoadInst>(&*It++);
97   auto *Ld1 = cast<sandboxir::LoadInst>(&*It++);
98   auto *Ld3 = cast<sandboxir::LoadInst>(&*It++);
99   auto *St0 = cast<sandboxir::StoreInst>(&*It++);
100   auto *St1 = cast<sandboxir::StoreInst>(&*It++);
101   auto *StVec2 = cast<sandboxir::StoreInst>(&*It++);
102   auto *StVec3 = cast<sandboxir::StoreInst>(&*It++);
103   auto *StI8 = cast<sandboxir::StoreInst>(&*It++);
104   auto *FAdd0 = cast<sandboxir::BinaryOperator>(&*It++);
105   auto *FAdd1 = cast<sandboxir::BinaryOperator>(&*It++);
106   auto *Trunc0 = cast<sandboxir::TruncInst>(&*It++);
107   auto *Trunc1 = cast<sandboxir::TruncInst>(&*It++);
108   auto *Trunc64to8 = cast<sandboxir::TruncInst>(&*It++);
109   auto *Trunc32to8 = cast<sandboxir::TruncInst>(&*It++);
110   auto *CmpSLT = cast<sandboxir::CmpInst>(&*It++);
111   auto *CmpSGT = cast<sandboxir::CmpInst>(&*It++);
112 
113   sandboxir::LegalityAnalysis Legality(*AA, *SE, DL, Ctx);
114   const auto &Result =
115       Legality.canVectorize({St0, St1}, /*SkipScheduling=*/true);
116   EXPECT_TRUE(isa<sandboxir::Widen>(Result));
117 
118   {
119     // Check NotInstructions
120     auto &Result = Legality.canVectorize({F, St0}, /*SkipScheduling=*/true);
121     EXPECT_TRUE(isa<sandboxir::Pack>(Result));
122     EXPECT_EQ(cast<sandboxir::Pack>(Result).getReason(),
123               sandboxir::ResultReason::NotInstructions);
124   }
125   {
126     // Check DiffOpcodes
127     const auto &Result =
128         Legality.canVectorize({St0, Ld0}, /*SkipScheduling=*/true);
129     EXPECT_TRUE(isa<sandboxir::Pack>(Result));
130     EXPECT_EQ(cast<sandboxir::Pack>(Result).getReason(),
131               sandboxir::ResultReason::DiffOpcodes);
132   }
133   {
134     // Check DiffTypes
135     EXPECT_TRUE(isa<sandboxir::Widen>(
136         Legality.canVectorize({St0, StVec2}, /*SkipScheduling=*/true)));
137     EXPECT_TRUE(isa<sandboxir::Widen>(
138         Legality.canVectorize({StVec2, StVec3}, /*SkipScheduling=*/true)));
139 
140     const auto &Result =
141         Legality.canVectorize({St0, StI8}, /*SkipScheduling=*/true);
142     EXPECT_TRUE(isa<sandboxir::Pack>(Result));
143     EXPECT_EQ(cast<sandboxir::Pack>(Result).getReason(),
144               sandboxir::ResultReason::DiffTypes);
145   }
146   {
147     // Check DiffMathFlags
148     const auto &Result =
149         Legality.canVectorize({FAdd0, FAdd1}, /*SkipScheduling=*/true);
150     EXPECT_TRUE(isa<sandboxir::Pack>(Result));
151     EXPECT_EQ(cast<sandboxir::Pack>(Result).getReason(),
152               sandboxir::ResultReason::DiffMathFlags);
153   }
154   {
155     // Check DiffWrapFlags
156     const auto &Result =
157         Legality.canVectorize({Trunc0, Trunc1}, /*SkipScheduling=*/true);
158     EXPECT_TRUE(isa<sandboxir::Pack>(Result));
159     EXPECT_EQ(cast<sandboxir::Pack>(Result).getReason(),
160               sandboxir::ResultReason::DiffWrapFlags);
161   }
162   {
163     // Check DiffTypes for unary operands that have a different type.
164     const auto &Result = Legality.canVectorize({Trunc64to8, Trunc32to8},
165                                                /*SkipScheduling=*/true);
166     EXPECT_TRUE(isa<sandboxir::Pack>(Result));
167     EXPECT_EQ(cast<sandboxir::Pack>(Result).getReason(),
168               sandboxir::ResultReason::DiffTypes);
169   }
170   {
171     // Check DiffOpcodes for CMPs with different predicates.
172     const auto &Result =
173         Legality.canVectorize({CmpSLT, CmpSGT}, /*SkipScheduling=*/true);
174     EXPECT_TRUE(isa<sandboxir::Pack>(Result));
175     EXPECT_EQ(cast<sandboxir::Pack>(Result).getReason(),
176               sandboxir::ResultReason::DiffOpcodes);
177   }
178   {
179     // Check NotConsecutive Ld0,Ld0b
180     const auto &Result =
181         Legality.canVectorize({Ld0, Ld0b}, /*SkipScheduling=*/true);
182     EXPECT_TRUE(isa<sandboxir::Pack>(Result));
183     EXPECT_EQ(cast<sandboxir::Pack>(Result).getReason(),
184               sandboxir::ResultReason::NotConsecutive);
185   }
186   {
187     // Check NotConsecutive Ld0,Ld3
188     const auto &Result =
189         Legality.canVectorize({Ld0, Ld3}, /*SkipScheduling=*/true);
190     EXPECT_TRUE(isa<sandboxir::Pack>(Result));
191     EXPECT_EQ(cast<sandboxir::Pack>(Result).getReason(),
192               sandboxir::ResultReason::NotConsecutive);
193   }
194   {
195     // Check Widen Ld0,Ld1
196     const auto &Result =
197         Legality.canVectorize({Ld0, Ld1}, /*SkipScheduling=*/true);
198     EXPECT_TRUE(isa<sandboxir::Widen>(Result));
199   }
200 }
201 
202 TEST_F(LegalityTest, LegalitySchedule) {
203   parseIR(C, R"IR(
204 define void @foo(ptr %ptr) {
205   %gep0 = getelementptr float, ptr %ptr, i32 0
206   %gep1 = getelementptr float, ptr %ptr, i32 1
207   %ld0 = load float, ptr %gep0
208   store float %ld0, ptr %gep1
209   %ld1 = load float, ptr %gep1
210   store float %ld0, ptr %gep0
211   store float %ld1, ptr %gep1
212   ret void
213 }
214 )IR");
215   llvm::Function *LLVMF = &*M->getFunction("foo");
216   getAnalyses(*LLVMF);
217   const auto &DL = M->getDataLayout();
218 
219   sandboxir::Context Ctx(C);
220   auto *F = Ctx.createFunction(LLVMF);
221   auto *BB = &*F->begin();
222   auto It = BB->begin();
223   [[maybe_unused]] auto *Gep0 = cast<sandboxir::GetElementPtrInst>(&*It++);
224   [[maybe_unused]] auto *Gep1 = cast<sandboxir::GetElementPtrInst>(&*It++);
225   auto *Ld0 = cast<sandboxir::LoadInst>(&*It++);
226   [[maybe_unused]] auto *ConflictingSt = cast<sandboxir::StoreInst>(&*It++);
227   auto *Ld1 = cast<sandboxir::LoadInst>(&*It++);
228   auto *St0 = cast<sandboxir::StoreInst>(&*It++);
229   auto *St1 = cast<sandboxir::StoreInst>(&*It++);
230 
231   sandboxir::LegalityAnalysis Legality(*AA, *SE, DL, Ctx);
232   {
233     // Can vectorize St0,St1.
234     const auto &Result = Legality.canVectorize({St0, St1});
235     EXPECT_TRUE(isa<sandboxir::Widen>(Result));
236   }
237   {
238     // Can't vectorize Ld0,Ld1 because of conflicting store.
239     auto &Result = Legality.canVectorize({Ld0, Ld1});
240     EXPECT_TRUE(isa<sandboxir::Pack>(Result));
241     EXPECT_EQ(cast<sandboxir::Pack>(Result).getReason(),
242               sandboxir::ResultReason::CantSchedule);
243   }
244 }
245 
246 #ifndef NDEBUG
247 TEST_F(LegalityTest, LegalityResultDump) {
248   parseIR(C, R"IR(
249 define void @foo() {
250   ret void
251 }
252 )IR");
253   llvm::Function *LLVMF = &*M->getFunction("foo");
254   getAnalyses(*LLVMF);
255   const auto &DL = M->getDataLayout();
256 
257   auto Matches = [](const sandboxir::LegalityResult &Result,
258                     const std::string &ExpectedStr) -> bool {
259     std::string Buff;
260     raw_string_ostream OS(Buff);
261     Result.print(OS);
262     return Buff == ExpectedStr;
263   };
264 
265   sandboxir::Context Ctx(C);
266   sandboxir::LegalityAnalysis Legality(*AA, *SE, DL, Ctx);
267   EXPECT_TRUE(
268       Matches(Legality.createLegalityResult<sandboxir::Widen>(), "Widen"));
269   EXPECT_TRUE(Matches(Legality.createLegalityResult<sandboxir::Pack>(
270                           sandboxir::ResultReason::NotInstructions),
271                       "Pack Reason: NotInstructions"));
272   EXPECT_TRUE(Matches(Legality.createLegalityResult<sandboxir::Pack>(
273                           sandboxir::ResultReason::DiffOpcodes),
274                       "Pack Reason: DiffOpcodes"));
275   EXPECT_TRUE(Matches(Legality.createLegalityResult<sandboxir::Pack>(
276                           sandboxir::ResultReason::DiffTypes),
277                       "Pack Reason: DiffTypes"));
278   EXPECT_TRUE(Matches(Legality.createLegalityResult<sandboxir::Pack>(
279                           sandboxir::ResultReason::DiffMathFlags),
280                       "Pack Reason: DiffMathFlags"));
281   EXPECT_TRUE(Matches(Legality.createLegalityResult<sandboxir::Pack>(
282                           sandboxir::ResultReason::DiffWrapFlags),
283                       "Pack Reason: DiffWrapFlags"));
284 }
285 #endif // NDEBUG
286