xref: /llvm-project/llvm/lib/Transforms/Vectorize/SandboxVectorizer/Legality.cpp (revision 6409799bdcd86be3ed72e8d172181294d3e5ad09)
1 //===- Legality.cpp -------------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "llvm/Transforms/Vectorize/SandboxVectorizer/Legality.h"
10 #include "llvm/SandboxIR/Instruction.h"
11 #include "llvm/SandboxIR/Operator.h"
12 #include "llvm/SandboxIR/Utils.h"
13 #include "llvm/SandboxIR/Value.h"
14 #include "llvm/Support/Debug.h"
15 #include "llvm/Transforms/Vectorize/SandboxVectorizer/InstrMaps.h"
16 #include "llvm/Transforms/Vectorize/SandboxVectorizer/VecUtils.h"
17 
18 namespace llvm::sandboxir {
19 
20 #define DEBUG_TYPE "SBVec:Legality"
21 
22 #ifndef NDEBUG
23 void ShuffleMask::dump() const {
24   print(dbgs());
25   dbgs() << "\n";
26 }
27 
28 void LegalityResult::dump() const {
29   print(dbgs());
30   dbgs() << "\n";
31 }
32 #endif // NDEBUG
33 
34 std::optional<ResultReason>
35 LegalityAnalysis::notVectorizableBasedOnOpcodesAndTypes(
36     ArrayRef<Value *> Bndl) {
37   auto *I0 = cast<Instruction>(Bndl[0]);
38   auto Opcode = I0->getOpcode();
39   // If they have different opcodes, then we cannot form a vector (for now).
40   if (any_of(drop_begin(Bndl), [Opcode](Value *V) {
41         return cast<Instruction>(V)->getOpcode() != Opcode;
42       }))
43     return ResultReason::DiffOpcodes;
44 
45   // If not the same scalar type, Pack. This will accept scalars and vectors as
46   // long as the element type is the same.
47   Type *ElmTy0 = VecUtils::getElementType(Utils::getExpectedType(I0));
48   if (any_of(drop_begin(Bndl), [ElmTy0](Value *V) {
49         return VecUtils::getElementType(Utils::getExpectedType(V)) != ElmTy0;
50       }))
51     return ResultReason::DiffTypes;
52 
53   // TODO: Allow vectorization of instrs with different flags as long as we
54   // change them to the least common one.
55   // For now pack if differnt FastMathFlags.
56   if (isa<FPMathOperator>(I0)) {
57     FastMathFlags FMF0 = cast<Instruction>(Bndl[0])->getFastMathFlags();
58     if (any_of(drop_begin(Bndl), [FMF0](auto *V) {
59           return cast<Instruction>(V)->getFastMathFlags() != FMF0;
60         }))
61       return ResultReason::DiffMathFlags;
62   }
63 
64   // TODO: Allow vectorization by using common flags.
65   // For now Pack if they don't have the same wrap flags.
66   bool CanHaveWrapFlags =
67       isa<OverflowingBinaryOperator>(I0) || isa<TruncInst>(I0);
68   if (CanHaveWrapFlags) {
69     bool NUW0 = I0->hasNoUnsignedWrap();
70     bool NSW0 = I0->hasNoSignedWrap();
71     if (any_of(drop_begin(Bndl), [NUW0, NSW0](auto *V) {
72           return cast<Instruction>(V)->hasNoUnsignedWrap() != NUW0 ||
73                  cast<Instruction>(V)->hasNoSignedWrap() != NSW0;
74         })) {
75       return ResultReason::DiffWrapFlags;
76     }
77   }
78 
79   // Now we need to do further checks for specific opcodes.
80   switch (Opcode) {
81   case Instruction::Opcode::ZExt:
82   case Instruction::Opcode::SExt:
83   case Instruction::Opcode::FPToUI:
84   case Instruction::Opcode::FPToSI:
85   case Instruction::Opcode::FPExt:
86   case Instruction::Opcode::PtrToInt:
87   case Instruction::Opcode::IntToPtr:
88   case Instruction::Opcode::SIToFP:
89   case Instruction::Opcode::UIToFP:
90   case Instruction::Opcode::Trunc:
91   case Instruction::Opcode::FPTrunc:
92   case Instruction::Opcode::BitCast: {
93     // We have already checked that they are of the same opcode.
94     assert(all_of(Bndl,
95                   [Opcode](Value *V) {
96                     return cast<Instruction>(V)->getOpcode() == Opcode;
97                   }) &&
98            "Different opcodes, should have early returned!");
99     // But for these opcodes we should also check the operand type.
100     Type *FromTy0 = Utils::getExpectedType(I0->getOperand(0));
101     if (any_of(drop_begin(Bndl), [FromTy0](Value *V) {
102           return Utils::getExpectedType(cast<User>(V)->getOperand(0)) !=
103                  FromTy0;
104         }))
105       return ResultReason::DiffTypes;
106     return std::nullopt;
107   }
108   case Instruction::Opcode::FCmp:
109   case Instruction::Opcode::ICmp: {
110     // We need the same predicate..
111     auto Pred0 = cast<CmpInst>(I0)->getPredicate();
112     bool Same = all_of(Bndl, [Pred0](Value *V) {
113       return cast<CmpInst>(V)->getPredicate() == Pred0;
114     });
115     if (Same)
116       return std::nullopt;
117     return ResultReason::DiffOpcodes;
118   }
119   case Instruction::Opcode::Select:
120   case Instruction::Opcode::FNeg:
121   case Instruction::Opcode::Add:
122   case Instruction::Opcode::FAdd:
123   case Instruction::Opcode::Sub:
124   case Instruction::Opcode::FSub:
125   case Instruction::Opcode::Mul:
126   case Instruction::Opcode::FMul:
127   case Instruction::Opcode::FRem:
128   case Instruction::Opcode::UDiv:
129   case Instruction::Opcode::SDiv:
130   case Instruction::Opcode::FDiv:
131   case Instruction::Opcode::URem:
132   case Instruction::Opcode::SRem:
133   case Instruction::Opcode::Shl:
134   case Instruction::Opcode::LShr:
135   case Instruction::Opcode::AShr:
136   case Instruction::Opcode::And:
137   case Instruction::Opcode::Or:
138   case Instruction::Opcode::Xor:
139     return std::nullopt;
140   case Instruction::Opcode::Load:
141     if (VecUtils::areConsecutive<LoadInst>(Bndl, SE, DL))
142       return std::nullopt;
143     return ResultReason::NotConsecutive;
144   case Instruction::Opcode::Store:
145     if (VecUtils::areConsecutive<StoreInst>(Bndl, SE, DL))
146       return std::nullopt;
147     return ResultReason::NotConsecutive;
148   case Instruction::Opcode::PHI:
149     return ResultReason::Unimplemented;
150   case Instruction::Opcode::Opaque:
151     return ResultReason::Unimplemented;
152   case Instruction::Opcode::Br:
153   case Instruction::Opcode::Ret:
154   case Instruction::Opcode::AddrSpaceCast:
155   case Instruction::Opcode::InsertElement:
156   case Instruction::Opcode::InsertValue:
157   case Instruction::Opcode::ExtractElement:
158   case Instruction::Opcode::ExtractValue:
159   case Instruction::Opcode::ShuffleVector:
160   case Instruction::Opcode::Call:
161   case Instruction::Opcode::GetElementPtr:
162   case Instruction::Opcode::Switch:
163     return ResultReason::Unimplemented;
164   case Instruction::Opcode::VAArg:
165   case Instruction::Opcode::Freeze:
166   case Instruction::Opcode::Fence:
167   case Instruction::Opcode::Invoke:
168   case Instruction::Opcode::CallBr:
169   case Instruction::Opcode::LandingPad:
170   case Instruction::Opcode::CatchPad:
171   case Instruction::Opcode::CleanupPad:
172   case Instruction::Opcode::CatchRet:
173   case Instruction::Opcode::CleanupRet:
174   case Instruction::Opcode::Resume:
175   case Instruction::Opcode::CatchSwitch:
176   case Instruction::Opcode::AtomicRMW:
177   case Instruction::Opcode::AtomicCmpXchg:
178   case Instruction::Opcode::Alloca:
179   case Instruction::Opcode::Unreachable:
180     return ResultReason::Infeasible;
181   }
182 
183   return std::nullopt;
184 }
185 
186 #ifndef NDEBUG
187 static void dumpBndl(ArrayRef<Value *> Bndl) {
188   for (auto *V : Bndl)
189     dbgs() << *V << "\n";
190 }
191 #endif // NDEBUG
192 
193 CollectDescr
194 LegalityAnalysis::getHowToCollectValues(ArrayRef<Value *> Bndl) const {
195   SmallVector<CollectDescr::ExtractElementDescr, 4> Vec;
196   Vec.reserve(Bndl.size());
197   for (auto [Lane, V] : enumerate(Bndl)) {
198     if (auto *VecOp = IMaps.getVectorForOrig(V)) {
199       // If there is a vector containing `V`, then get the lane it came from.
200       std::optional<int> ExtractIdxOpt = IMaps.getOrigLane(VecOp, V);
201       Vec.emplace_back(VecOp, ExtractIdxOpt ? *ExtractIdxOpt : -1);
202     } else {
203       Vec.emplace_back(V);
204     }
205   }
206   return CollectDescr(std::move(Vec));
207 }
208 
209 const LegalityResult &LegalityAnalysis::canVectorize(ArrayRef<Value *> Bndl,
210                                                      bool SkipScheduling) {
211   // If Bndl contains values other than instructions, we need to Pack.
212   if (any_of(Bndl, [](auto *V) { return !isa<Instruction>(V); })) {
213     LLVM_DEBUG(dbgs() << "Not vectorizing: Not Instructions:\n";
214                dumpBndl(Bndl););
215     return createLegalityResult<Pack>(ResultReason::NotInstructions);
216   }
217   // Pack if not in the same BB.
218   auto *BB = cast<Instruction>(Bndl[0])->getParent();
219   if (any_of(drop_begin(Bndl),
220              [BB](auto *V) { return cast<Instruction>(V)->getParent() != BB; }))
221     return createLegalityResult<Pack>(ResultReason::DiffBBs);
222 
223   auto CollectDescrs = getHowToCollectValues(Bndl);
224   if (CollectDescrs.hasVectorInputs()) {
225     if (auto ValueShuffleOpt = CollectDescrs.getSingleInput()) {
226       auto [Vec, Mask] = *ValueShuffleOpt;
227       if (Mask.isIdentity())
228         return createLegalityResult<DiamondReuse>(Vec);
229       return createLegalityResult<DiamondReuseWithShuffle>(Vec, Mask);
230     }
231     return createLegalityResult<DiamondReuseMultiInput>(
232         std::move(CollectDescrs));
233   }
234 
235   if (auto ReasonOpt = notVectorizableBasedOnOpcodesAndTypes(Bndl))
236     return createLegalityResult<Pack>(*ReasonOpt);
237 
238   if (!SkipScheduling) {
239     // TODO: Try to remove the IBndl vector.
240     SmallVector<Instruction *, 8> IBndl;
241     IBndl.reserve(Bndl.size());
242     for (auto *V : Bndl)
243       IBndl.push_back(cast<Instruction>(V));
244     if (!Sched.trySchedule(IBndl))
245       return createLegalityResult<Pack>(ResultReason::CantSchedule);
246   }
247 
248   return createLegalityResult<Widen>();
249 }
250 
251 void LegalityAnalysis::clear() {
252   Sched.clear();
253   IMaps.clear();
254 }
255 } // namespace llvm::sandboxir
256