xref: /llvm-project/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/LegalityTest.cpp (revision 6409799bdcd86be3ed72e8d172181294d3e5ad09)
1 //===- LegalityTest.cpp ---------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "llvm/Transforms/Vectorize/SandboxVectorizer/Legality.h"
10 #include "llvm/Analysis/AssumptionCache.h"
11 #include "llvm/Analysis/BasicAliasAnalysis.h"
12 #include "llvm/Analysis/LoopInfo.h"
13 #include "llvm/Analysis/ScalarEvolution.h"
14 #include "llvm/Analysis/TargetLibraryInfo.h"
15 #include "llvm/AsmParser/Parser.h"
16 #include "llvm/IR/DataLayout.h"
17 #include "llvm/IR/Dominators.h"
18 #include "llvm/SandboxIR/Function.h"
19 #include "llvm/SandboxIR/Instruction.h"
20 #include "llvm/Support/SourceMgr.h"
21 #include "llvm/Transforms/Vectorize/SandboxVectorizer/InstrMaps.h"
22 #include "gmock/gmock.h"
23 #include "gtest/gtest.h"
24 
25 using namespace llvm;
26 
27 struct LegalityTest : public testing::Test {
28   LLVMContext C;
29   std::unique_ptr<Module> M;
30   std::unique_ptr<DominatorTree> DT;
31   std::unique_ptr<TargetLibraryInfoImpl> TLII;
32   std::unique_ptr<TargetLibraryInfo> TLI;
33   std::unique_ptr<AssumptionCache> AC;
34   std::unique_ptr<LoopInfo> LI;
35   std::unique_ptr<ScalarEvolution> SE;
36   std::unique_ptr<BasicAAResult> BAA;
37   std::unique_ptr<AAResults> AA;
38 
39   void getAnalyses(llvm::Function &LLVMF) {
40     DT = std::make_unique<DominatorTree>(LLVMF);
41     TLII = std::make_unique<TargetLibraryInfoImpl>();
42     TLI = std::make_unique<TargetLibraryInfo>(*TLII);
43     AC = std::make_unique<AssumptionCache>(LLVMF);
44     LI = std::make_unique<LoopInfo>(*DT);
45     SE = std::make_unique<ScalarEvolution>(LLVMF, *TLI, *AC, *DT, *LI);
46     BAA = std::make_unique<BasicAAResult>(LLVMF.getParent()->getDataLayout(),
47                                           LLVMF, *TLI, *AC, DT.get());
48     AA = std::make_unique<AAResults>(*TLI);
49     AA->addAAResult(*BAA);
50   }
51 
52   void parseIR(LLVMContext &C, const char *IR) {
53     SMDiagnostic Err;
54     M = parseAssemblyString(IR, Err, C);
55     if (!M)
56       Err.print("LegalityTest", errs());
57   }
58 };
59 
60 static sandboxir::BasicBlock *getBasicBlockByName(sandboxir::Function *F,
61                                                   StringRef Name) {
62   for (sandboxir::BasicBlock &BB : *F)
63     if (BB.getName() == Name)
64       return &BB;
65   llvm_unreachable("Expected to find basic block!");
66 }
67 
68 TEST_F(LegalityTest, LegalitySkipSchedule) {
69   parseIR(C, R"IR(
70 define void @foo(ptr %ptr, <2 x float> %vec2, <3 x float> %vec3, i8 %arg, float %farg0, float %farg1, i64 %v0, i64 %v1, i32 %v2) {
71 entry:
72   %gep0 = getelementptr float, ptr %ptr, i32 0
73   %gep1 = getelementptr float, ptr %ptr, i32 1
74   store float %farg0, ptr %gep1
75   br label %bb
76 
77 bb:
78   %gep3 = getelementptr float, ptr %ptr, i32 3
79   %ld0 = load float, ptr %gep0
80   %ld0b = load float, ptr %gep0
81   %ld1 = load float, ptr %gep1
82   %ld3 = load float, ptr %gep3
83   store float %ld0, ptr %gep0
84   store float %ld1, ptr %gep1
85   store <2 x float> %vec2, ptr %gep1
86   store <3 x float> %vec3, ptr %gep3
87   store i8 %arg, ptr %gep1
88   %fadd0 = fadd float %farg0, %farg0
89   %fadd1 = fadd fast float %farg1, %farg1
90   %trunc0 = trunc nuw nsw i64 %v0 to i8
91   %trunc1 = trunc nsw i64 %v1 to i8
92   %trunc64to8 = trunc i64 %v0 to i8
93   %trunc32to8 = trunc i32 %v2 to i8
94   %cmpSLT = icmp slt i64 %v0, %v1
95   %cmpSGT = icmp sgt i64 %v0, %v1
96   ret void
97 }
98 )IR");
99   llvm::Function *LLVMF = &*M->getFunction("foo");
100   getAnalyses(*LLVMF);
101   const auto &DL = M->getDataLayout();
102 
103   sandboxir::Context Ctx(C);
104   auto *F = Ctx.createFunction(LLVMF);
105   auto *EntryBB = getBasicBlockByName(F, "entry");
106   auto It = EntryBB->begin();
107   [[maybe_unused]] auto *Gep0 = cast<sandboxir::GetElementPtrInst>(&*It++);
108   [[maybe_unused]] auto *Gep1 = cast<sandboxir::GetElementPtrInst>(&*It++);
109   auto *St1Entry = cast<sandboxir::StoreInst>(&*It++);
110 
111   auto *BB = getBasicBlockByName(F, "bb");
112   It = BB->begin();
113   [[maybe_unused]] auto *Gep3 = cast<sandboxir::GetElementPtrInst>(&*It++);
114   auto *Ld0 = cast<sandboxir::LoadInst>(&*It++);
115   auto *Ld0b = cast<sandboxir::LoadInst>(&*It++);
116   auto *Ld1 = cast<sandboxir::LoadInst>(&*It++);
117   auto *Ld3 = cast<sandboxir::LoadInst>(&*It++);
118   auto *St0 = cast<sandboxir::StoreInst>(&*It++);
119   auto *St1 = cast<sandboxir::StoreInst>(&*It++);
120   auto *StVec2 = cast<sandboxir::StoreInst>(&*It++);
121   auto *StVec3 = cast<sandboxir::StoreInst>(&*It++);
122   auto *StI8 = cast<sandboxir::StoreInst>(&*It++);
123   auto *FAdd0 = cast<sandboxir::BinaryOperator>(&*It++);
124   auto *FAdd1 = cast<sandboxir::BinaryOperator>(&*It++);
125   auto *Trunc0 = cast<sandboxir::TruncInst>(&*It++);
126   auto *Trunc1 = cast<sandboxir::TruncInst>(&*It++);
127   auto *Trunc64to8 = cast<sandboxir::TruncInst>(&*It++);
128   auto *Trunc32to8 = cast<sandboxir::TruncInst>(&*It++);
129   auto *CmpSLT = cast<sandboxir::CmpInst>(&*It++);
130   auto *CmpSGT = cast<sandboxir::CmpInst>(&*It++);
131 
132   llvm::sandboxir::InstrMaps IMaps(Ctx);
133   sandboxir::LegalityAnalysis Legality(*AA, *SE, DL, Ctx, IMaps);
134   const auto &Result =
135       Legality.canVectorize({St0, St1}, /*SkipScheduling=*/true);
136   EXPECT_TRUE(isa<sandboxir::Widen>(Result));
137 
138   {
139     // Check NotInstructions
140     auto &Result = Legality.canVectorize({F, St0}, /*SkipScheduling=*/true);
141     EXPECT_TRUE(isa<sandboxir::Pack>(Result));
142     EXPECT_EQ(cast<sandboxir::Pack>(Result).getReason(),
143               sandboxir::ResultReason::NotInstructions);
144   }
145   {
146     // Check DiffOpcodes
147     const auto &Result =
148         Legality.canVectorize({St0, Ld0}, /*SkipScheduling=*/true);
149     EXPECT_TRUE(isa<sandboxir::Pack>(Result));
150     EXPECT_EQ(cast<sandboxir::Pack>(Result).getReason(),
151               sandboxir::ResultReason::DiffOpcodes);
152   }
153   {
154     // Check DiffTypes
155     EXPECT_TRUE(isa<sandboxir::Widen>(
156         Legality.canVectorize({St0, StVec2}, /*SkipScheduling=*/true)));
157     EXPECT_TRUE(isa<sandboxir::Widen>(
158         Legality.canVectorize({StVec2, StVec3}, /*SkipScheduling=*/true)));
159 
160     const auto &Result =
161         Legality.canVectorize({St0, StI8}, /*SkipScheduling=*/true);
162     EXPECT_TRUE(isa<sandboxir::Pack>(Result));
163     EXPECT_EQ(cast<sandboxir::Pack>(Result).getReason(),
164               sandboxir::ResultReason::DiffTypes);
165   }
166   {
167     // Check DiffMathFlags
168     const auto &Result =
169         Legality.canVectorize({FAdd0, FAdd1}, /*SkipScheduling=*/true);
170     EXPECT_TRUE(isa<sandboxir::Pack>(Result));
171     EXPECT_EQ(cast<sandboxir::Pack>(Result).getReason(),
172               sandboxir::ResultReason::DiffMathFlags);
173   }
174   {
175     // Check DiffWrapFlags
176     const auto &Result =
177         Legality.canVectorize({Trunc0, Trunc1}, /*SkipScheduling=*/true);
178     EXPECT_TRUE(isa<sandboxir::Pack>(Result));
179     EXPECT_EQ(cast<sandboxir::Pack>(Result).getReason(),
180               sandboxir::ResultReason::DiffWrapFlags);
181   }
182   {
183     // Check DiffBBs
184     const auto &Result =
185         Legality.canVectorize({St0, St1Entry}, /*SkipScheduling=*/true);
186     EXPECT_TRUE(isa<sandboxir::Pack>(Result));
187     EXPECT_EQ(cast<sandboxir::Pack>(Result).getReason(),
188               sandboxir::ResultReason::DiffBBs);
189   }
190   {
191     // Check DiffTypes for unary operands that have a different type.
192     const auto &Result = Legality.canVectorize({Trunc64to8, Trunc32to8},
193                                                /*SkipScheduling=*/true);
194     EXPECT_TRUE(isa<sandboxir::Pack>(Result));
195     EXPECT_EQ(cast<sandboxir::Pack>(Result).getReason(),
196               sandboxir::ResultReason::DiffTypes);
197   }
198   {
199     // Check DiffOpcodes for CMPs with different predicates.
200     const auto &Result =
201         Legality.canVectorize({CmpSLT, CmpSGT}, /*SkipScheduling=*/true);
202     EXPECT_TRUE(isa<sandboxir::Pack>(Result));
203     EXPECT_EQ(cast<sandboxir::Pack>(Result).getReason(),
204               sandboxir::ResultReason::DiffOpcodes);
205   }
206   {
207     // Check NotConsecutive Ld0,Ld0b
208     const auto &Result =
209         Legality.canVectorize({Ld0, Ld0b}, /*SkipScheduling=*/true);
210     EXPECT_TRUE(isa<sandboxir::Pack>(Result));
211     EXPECT_EQ(cast<sandboxir::Pack>(Result).getReason(),
212               sandboxir::ResultReason::NotConsecutive);
213   }
214   {
215     // Check NotConsecutive Ld0,Ld3
216     const auto &Result =
217         Legality.canVectorize({Ld0, Ld3}, /*SkipScheduling=*/true);
218     EXPECT_TRUE(isa<sandboxir::Pack>(Result));
219     EXPECT_EQ(cast<sandboxir::Pack>(Result).getReason(),
220               sandboxir::ResultReason::NotConsecutive);
221   }
222   {
223     // Check Widen Ld0,Ld1
224     const auto &Result =
225         Legality.canVectorize({Ld0, Ld1}, /*SkipScheduling=*/true);
226     EXPECT_TRUE(isa<sandboxir::Widen>(Result));
227   }
228 }
229 
230 TEST_F(LegalityTest, LegalitySchedule) {
231   parseIR(C, R"IR(
232 define void @foo(ptr %ptr) {
233   %gep0 = getelementptr float, ptr %ptr, i32 0
234   %gep1 = getelementptr float, ptr %ptr, i32 1
235   %ld0 = load float, ptr %gep0
236   store float %ld0, ptr %gep1
237   %ld1 = load float, ptr %gep1
238   store float %ld0, ptr %gep0
239   store float %ld1, ptr %gep1
240   ret void
241 }
242 )IR");
243   llvm::Function *LLVMF = &*M->getFunction("foo");
244   getAnalyses(*LLVMF);
245   const auto &DL = M->getDataLayout();
246 
247   sandboxir::Context Ctx(C);
248   auto *F = Ctx.createFunction(LLVMF);
249   auto *BB = &*F->begin();
250   auto It = BB->begin();
251   [[maybe_unused]] auto *Gep0 = cast<sandboxir::GetElementPtrInst>(&*It++);
252   [[maybe_unused]] auto *Gep1 = cast<sandboxir::GetElementPtrInst>(&*It++);
253   auto *Ld0 = cast<sandboxir::LoadInst>(&*It++);
254   [[maybe_unused]] auto *ConflictingSt = cast<sandboxir::StoreInst>(&*It++);
255   auto *Ld1 = cast<sandboxir::LoadInst>(&*It++);
256   auto *St0 = cast<sandboxir::StoreInst>(&*It++);
257   auto *St1 = cast<sandboxir::StoreInst>(&*It++);
258 
259   llvm::sandboxir::InstrMaps IMaps(Ctx);
260   sandboxir::LegalityAnalysis Legality(*AA, *SE, DL, Ctx, IMaps);
261   {
262     // Can vectorize St0,St1.
263     const auto &Result = Legality.canVectorize({St0, St1});
264     EXPECT_TRUE(isa<sandboxir::Widen>(Result));
265   }
266   {
267     // Can't vectorize Ld0,Ld1 because of conflicting store.
268     auto &Result = Legality.canVectorize({Ld0, Ld1});
269     EXPECT_TRUE(isa<sandboxir::Pack>(Result));
270     EXPECT_EQ(cast<sandboxir::Pack>(Result).getReason(),
271               sandboxir::ResultReason::CantSchedule);
272   }
273 }
274 
275 #ifndef NDEBUG
276 TEST_F(LegalityTest, LegalityResultDump) {
277   parseIR(C, R"IR(
278 define void @foo() {
279   ret void
280 }
281 )IR");
282   llvm::Function *LLVMF = &*M->getFunction("foo");
283   getAnalyses(*LLVMF);
284   const auto &DL = M->getDataLayout();
285 
286   auto Matches = [](const sandboxir::LegalityResult &Result,
287                     const std::string &ExpectedStr) -> bool {
288     std::string Buff;
289     raw_string_ostream OS(Buff);
290     Result.print(OS);
291     return Buff == ExpectedStr;
292   };
293 
294   sandboxir::Context Ctx(C);
295   llvm::sandboxir::InstrMaps IMaps(Ctx);
296   sandboxir::LegalityAnalysis Legality(*AA, *SE, DL, Ctx, IMaps);
297   EXPECT_TRUE(
298       Matches(Legality.createLegalityResult<sandboxir::Widen>(), "Widen"));
299   EXPECT_TRUE(Matches(Legality.createLegalityResult<sandboxir::Pack>(
300                           sandboxir::ResultReason::NotInstructions),
301                       "Pack Reason: NotInstructions"));
302   EXPECT_TRUE(Matches(Legality.createLegalityResult<sandboxir::Pack>(
303                           sandboxir::ResultReason::DiffOpcodes),
304                       "Pack Reason: DiffOpcodes"));
305   EXPECT_TRUE(Matches(Legality.createLegalityResult<sandboxir::Pack>(
306                           sandboxir::ResultReason::DiffTypes),
307                       "Pack Reason: DiffTypes"));
308   EXPECT_TRUE(Matches(Legality.createLegalityResult<sandboxir::Pack>(
309                           sandboxir::ResultReason::DiffMathFlags),
310                       "Pack Reason: DiffMathFlags"));
311   EXPECT_TRUE(Matches(Legality.createLegalityResult<sandboxir::Pack>(
312                           sandboxir::ResultReason::DiffWrapFlags),
313                       "Pack Reason: DiffWrapFlags"));
314 }
315 #endif // NDEBUG
316 
317 TEST_F(LegalityTest, CollectDescr) {
318   parseIR(C, R"IR(
319 define void @foo(ptr %ptr) {
320   %gep0 = getelementptr float, ptr %ptr, i32 0
321   %gep1 = getelementptr float, ptr %ptr, i32 1
322   %ld0 = load float, ptr %gep0
323   %ld1 = load float, ptr %gep1
324   %vld = load <4 x float>, ptr %ptr
325   ret void
326 }
327 )IR");
328   llvm::Function *LLVMF = &*M->getFunction("foo");
329   getAnalyses(*LLVMF);
330   sandboxir::Context Ctx(C);
331   auto *F = Ctx.createFunction(LLVMF);
332   auto *BB = &*F->begin();
333   auto It = BB->begin();
334   [[maybe_unused]] auto *Gep0 = cast<sandboxir::GetElementPtrInst>(&*It++);
335   [[maybe_unused]] auto *Gep1 = cast<sandboxir::GetElementPtrInst>(&*It++);
336   auto *Ld0 = cast<sandboxir::LoadInst>(&*It++);
337   [[maybe_unused]] auto *Ld1 = cast<sandboxir::LoadInst>(&*It++);
338   auto *VLd = cast<sandboxir::LoadInst>(&*It++);
339 
340   sandboxir::CollectDescr::DescrVecT Descrs;
341   using EEDescr = sandboxir::CollectDescr::ExtractElementDescr;
342 
343   {
344     // Check single input, no shuffle.
345     Descrs.push_back(EEDescr(VLd, 0));
346     Descrs.push_back(EEDescr(VLd, 1));
347     sandboxir::CollectDescr CD(std::move(Descrs));
348     EXPECT_TRUE(CD.getSingleInput());
349     EXPECT_EQ(CD.getSingleInput()->first, VLd);
350     EXPECT_THAT(CD.getSingleInput()->second, testing::ElementsAre(0, 1));
351     EXPECT_TRUE(CD.hasVectorInputs());
352   }
353   {
354     // Check single input, shuffle.
355     Descrs.push_back(EEDescr(VLd, 1));
356     Descrs.push_back(EEDescr(VLd, 0));
357     sandboxir::CollectDescr CD(std::move(Descrs));
358     EXPECT_TRUE(CD.getSingleInput());
359     EXPECT_EQ(CD.getSingleInput()->first, VLd);
360     EXPECT_THAT(CD.getSingleInput()->second, testing::ElementsAre(1, 0));
361     EXPECT_TRUE(CD.hasVectorInputs());
362   }
363   {
364     // Check multiple inputs.
365     Descrs.push_back(EEDescr(Ld0));
366     Descrs.push_back(EEDescr(VLd, 0));
367     Descrs.push_back(EEDescr(VLd, 1));
368     sandboxir::CollectDescr CD(std::move(Descrs));
369     EXPECT_FALSE(CD.getSingleInput());
370     EXPECT_TRUE(CD.hasVectorInputs());
371   }
372   {
373     // Check multiple inputs only scalars.
374     Descrs.push_back(EEDescr(Ld0));
375     Descrs.push_back(EEDescr(Ld1));
376     sandboxir::CollectDescr CD(std::move(Descrs));
377     EXPECT_FALSE(CD.getSingleInput());
378     EXPECT_FALSE(CD.hasVectorInputs());
379   }
380 }
381 
382 TEST_F(LegalityTest, ShuffleMask) {
383   {
384     // Check SmallVector constructor.
385     SmallVector<int> Indices({0, 1, 2, 3});
386     sandboxir::ShuffleMask Mask(std::move(Indices));
387     EXPECT_THAT(Mask, testing::ElementsAre(0, 1, 2, 3));
388   }
389   {
390     // Check initializer_list constructor.
391     sandboxir::ShuffleMask Mask({0, 1, 2, 3});
392     EXPECT_THAT(Mask, testing::ElementsAre(0, 1, 2, 3));
393   }
394   {
395     // Check ArrayRef constructor.
396     sandboxir::ShuffleMask Mask(ArrayRef<int>({0, 1, 2, 3}));
397     EXPECT_THAT(Mask, testing::ElementsAre(0, 1, 2, 3));
398   }
399   {
400     // Check operator ArrayRef<int>().
401     sandboxir::ShuffleMask Mask({0, 1, 2, 3});
402     ArrayRef<int> Array = Mask;
403     EXPECT_THAT(Array, testing::ElementsAre(0, 1, 2, 3));
404   }
405   {
406     // Check getIdentity().
407     auto IdentityMask = sandboxir::ShuffleMask::getIdentity(4);
408     EXPECT_THAT(IdentityMask, testing::ElementsAre(0, 1, 2, 3));
409     EXPECT_TRUE(IdentityMask.isIdentity());
410   }
411   {
412     // Check isIdentity().
413     sandboxir::ShuffleMask Mask1({0, 1, 2, 3});
414     EXPECT_TRUE(Mask1.isIdentity());
415     sandboxir::ShuffleMask Mask2({1, 2, 3, 4});
416     EXPECT_FALSE(Mask2.isIdentity());
417   }
418   {
419     // Check operator==().
420     sandboxir::ShuffleMask Mask1({0, 1, 2, 3});
421     sandboxir::ShuffleMask Mask2({0, 1, 2, 3});
422     EXPECT_TRUE(Mask1 == Mask2);
423     EXPECT_FALSE(Mask1 != Mask2);
424   }
425   {
426     // Check operator!=().
427     sandboxir::ShuffleMask Mask1({0, 1, 2, 3});
428     sandboxir::ShuffleMask Mask2({0, 1, 2, 4});
429     EXPECT_TRUE(Mask1 != Mask2);
430     EXPECT_FALSE(Mask1 == Mask2);
431   }
432   {
433     // Check size().
434     sandboxir::ShuffleMask Mask({0, 1, 2, 3});
435     EXPECT_EQ(Mask.size(), 4u);
436   }
437   {
438     // Check operator[].
439     sandboxir::ShuffleMask Mask({0, 1, 2, 3});
440     for (auto [Idx, Elm] : enumerate(Mask)) {
441       EXPECT_EQ(Elm, Mask[Idx]);
442     }
443   }
444   {
445     // Check begin(), end().
446     sandboxir::ShuffleMask Mask({0, 1, 2, 3});
447     sandboxir::ShuffleMask::const_iterator Begin = Mask.begin();
448     sandboxir::ShuffleMask::const_iterator End = Mask.begin();
449     int Idx = 0;
450     for (auto It = Begin; It != End; ++It) {
451       EXPECT_EQ(*It, Mask[Idx++]);
452     }
453   }
454 #ifndef NDEBUG
455   {
456     // Check print(OS).
457     sandboxir::ShuffleMask Mask({0, 1, 2, 3});
458     std::string Str;
459     raw_string_ostream OS(Str);
460     Mask.print(OS);
461     EXPECT_EQ(Str, "0,1,2,3");
462   }
463   {
464     // Check operator<<().
465     sandboxir::ShuffleMask Mask({0, 1, 2, 3});
466     std::string Str;
467     raw_string_ostream OS(Str);
468     OS << Mask;
469     EXPECT_EQ(Str, "0,1,2,3");
470   }
471 #endif // NDEBUG
472 }
473