xref: /llvm-project/llvm/lib/Transforms/Vectorize/SandboxVectorizer/Legality.cpp (revision e902c6960cff4372d4b3ef9ae424b24ec6b0ea38)
1 //===- Legality.cpp -------------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "llvm/Transforms/Vectorize/SandboxVectorizer/Legality.h"
10 #include "llvm/SandboxIR/Instruction.h"
11 #include "llvm/SandboxIR/Operator.h"
12 #include "llvm/SandboxIR/Utils.h"
13 #include "llvm/SandboxIR/Value.h"
14 #include "llvm/Support/Debug.h"
15 #include "llvm/Transforms/Vectorize/SandboxVectorizer/InstrMaps.h"
16 #include "llvm/Transforms/Vectorize/SandboxVectorizer/VecUtils.h"
17 
18 namespace llvm::sandboxir {
19 
20 #define DEBUG_TYPE "SBVec:Legality"
21 
22 #ifndef NDEBUG
23 void LegalityResult::dump() const {
24   print(dbgs());
25   dbgs() << "\n";
26 }
27 #endif // NDEBUG
28 
29 std::optional<ResultReason>
30 LegalityAnalysis::notVectorizableBasedOnOpcodesAndTypes(
31     ArrayRef<Value *> Bndl) {
32   auto *I0 = cast<Instruction>(Bndl[0]);
33   auto Opcode = I0->getOpcode();
34   // If they have different opcodes, then we cannot form a vector (for now).
35   if (any_of(drop_begin(Bndl), [Opcode](Value *V) {
36         return cast<Instruction>(V)->getOpcode() != Opcode;
37       }))
38     return ResultReason::DiffOpcodes;
39 
40   // If not the same scalar type, Pack. This will accept scalars and vectors as
41   // long as the element type is the same.
42   Type *ElmTy0 = VecUtils::getElementType(Utils::getExpectedType(I0));
43   if (any_of(drop_begin(Bndl), [ElmTy0](Value *V) {
44         return VecUtils::getElementType(Utils::getExpectedType(V)) != ElmTy0;
45       }))
46     return ResultReason::DiffTypes;
47 
48   // TODO: Allow vectorization of instrs with different flags as long as we
49   // change them to the least common one.
50   // For now pack if differnt FastMathFlags.
51   if (isa<FPMathOperator>(I0)) {
52     FastMathFlags FMF0 = cast<Instruction>(Bndl[0])->getFastMathFlags();
53     if (any_of(drop_begin(Bndl), [FMF0](auto *V) {
54           return cast<Instruction>(V)->getFastMathFlags() != FMF0;
55         }))
56       return ResultReason::DiffMathFlags;
57   }
58 
59   // TODO: Allow vectorization by using common flags.
60   // For now Pack if they don't have the same wrap flags.
61   bool CanHaveWrapFlags =
62       isa<OverflowingBinaryOperator>(I0) || isa<TruncInst>(I0);
63   if (CanHaveWrapFlags) {
64     bool NUW0 = I0->hasNoUnsignedWrap();
65     bool NSW0 = I0->hasNoSignedWrap();
66     if (any_of(drop_begin(Bndl), [NUW0, NSW0](auto *V) {
67           return cast<Instruction>(V)->hasNoUnsignedWrap() != NUW0 ||
68                  cast<Instruction>(V)->hasNoSignedWrap() != NSW0;
69         })) {
70       return ResultReason::DiffWrapFlags;
71     }
72   }
73 
74   // Now we need to do further checks for specific opcodes.
75   switch (Opcode) {
76   case Instruction::Opcode::ZExt:
77   case Instruction::Opcode::SExt:
78   case Instruction::Opcode::FPToUI:
79   case Instruction::Opcode::FPToSI:
80   case Instruction::Opcode::FPExt:
81   case Instruction::Opcode::PtrToInt:
82   case Instruction::Opcode::IntToPtr:
83   case Instruction::Opcode::SIToFP:
84   case Instruction::Opcode::UIToFP:
85   case Instruction::Opcode::Trunc:
86   case Instruction::Opcode::FPTrunc:
87   case Instruction::Opcode::BitCast: {
88     // We have already checked that they are of the same opcode.
89     assert(all_of(Bndl,
90                   [Opcode](Value *V) {
91                     return cast<Instruction>(V)->getOpcode() == Opcode;
92                   }) &&
93            "Different opcodes, should have early returned!");
94     // But for these opcodes we should also check the operand type.
95     Type *FromTy0 = Utils::getExpectedType(I0->getOperand(0));
96     if (any_of(drop_begin(Bndl), [FromTy0](Value *V) {
97           return Utils::getExpectedType(cast<User>(V)->getOperand(0)) !=
98                  FromTy0;
99         }))
100       return ResultReason::DiffTypes;
101     return std::nullopt;
102   }
103   case Instruction::Opcode::FCmp:
104   case Instruction::Opcode::ICmp: {
105     // We need the same predicate..
106     auto Pred0 = cast<CmpInst>(I0)->getPredicate();
107     bool Same = all_of(Bndl, [Pred0](Value *V) {
108       return cast<CmpInst>(V)->getPredicate() == Pred0;
109     });
110     if (Same)
111       return std::nullopt;
112     return ResultReason::DiffOpcodes;
113   }
114   case Instruction::Opcode::Select:
115   case Instruction::Opcode::FNeg:
116   case Instruction::Opcode::Add:
117   case Instruction::Opcode::FAdd:
118   case Instruction::Opcode::Sub:
119   case Instruction::Opcode::FSub:
120   case Instruction::Opcode::Mul:
121   case Instruction::Opcode::FMul:
122   case Instruction::Opcode::FRem:
123   case Instruction::Opcode::UDiv:
124   case Instruction::Opcode::SDiv:
125   case Instruction::Opcode::FDiv:
126   case Instruction::Opcode::URem:
127   case Instruction::Opcode::SRem:
128   case Instruction::Opcode::Shl:
129   case Instruction::Opcode::LShr:
130   case Instruction::Opcode::AShr:
131   case Instruction::Opcode::And:
132   case Instruction::Opcode::Or:
133   case Instruction::Opcode::Xor:
134     return std::nullopt;
135   case Instruction::Opcode::Load:
136     if (VecUtils::areConsecutive<LoadInst>(Bndl, SE, DL))
137       return std::nullopt;
138     return ResultReason::NotConsecutive;
139   case Instruction::Opcode::Store:
140     if (VecUtils::areConsecutive<StoreInst>(Bndl, SE, DL))
141       return std::nullopt;
142     return ResultReason::NotConsecutive;
143   case Instruction::Opcode::PHI:
144     return ResultReason::Unimplemented;
145   case Instruction::Opcode::Opaque:
146     return ResultReason::Unimplemented;
147   case Instruction::Opcode::Br:
148   case Instruction::Opcode::Ret:
149   case Instruction::Opcode::AddrSpaceCast:
150   case Instruction::Opcode::InsertElement:
151   case Instruction::Opcode::InsertValue:
152   case Instruction::Opcode::ExtractElement:
153   case Instruction::Opcode::ExtractValue:
154   case Instruction::Opcode::ShuffleVector:
155   case Instruction::Opcode::Call:
156   case Instruction::Opcode::GetElementPtr:
157   case Instruction::Opcode::Switch:
158     return ResultReason::Unimplemented;
159   case Instruction::Opcode::VAArg:
160   case Instruction::Opcode::Freeze:
161   case Instruction::Opcode::Fence:
162   case Instruction::Opcode::Invoke:
163   case Instruction::Opcode::CallBr:
164   case Instruction::Opcode::LandingPad:
165   case Instruction::Opcode::CatchPad:
166   case Instruction::Opcode::CleanupPad:
167   case Instruction::Opcode::CatchRet:
168   case Instruction::Opcode::CleanupRet:
169   case Instruction::Opcode::Resume:
170   case Instruction::Opcode::CatchSwitch:
171   case Instruction::Opcode::AtomicRMW:
172   case Instruction::Opcode::AtomicCmpXchg:
173   case Instruction::Opcode::Alloca:
174   case Instruction::Opcode::Unreachable:
175     return ResultReason::Infeasible;
176   }
177 
178   return std::nullopt;
179 }
180 
181 #ifndef NDEBUG
182 static void dumpBndl(ArrayRef<Value *> Bndl) {
183   for (auto *V : Bndl)
184     dbgs() << *V << "\n";
185 }
186 #endif // NDEBUG
187 
188 CollectDescr
189 LegalityAnalysis::getHowToCollectValues(ArrayRef<Value *> Bndl) const {
190   SmallVector<CollectDescr::ExtractElementDescr, 4> Vec;
191   Vec.reserve(Bndl.size());
192   for (auto [Lane, V] : enumerate(Bndl)) {
193     if (auto *VecOp = IMaps.getVectorForOrig(V)) {
194       // If there is a vector containing `V`, then get the lane it came from.
195       std::optional<int> ExtractIdxOpt = IMaps.getOrigLane(VecOp, V);
196       Vec.emplace_back(VecOp, ExtractIdxOpt ? *ExtractIdxOpt : -1);
197     } else {
198       Vec.emplace_back(V);
199     }
200   }
201   return CollectDescr(std::move(Vec));
202 }
203 
204 const LegalityResult &LegalityAnalysis::canVectorize(ArrayRef<Value *> Bndl,
205                                                      bool SkipScheduling) {
206   // If Bndl contains values other than instructions, we need to Pack.
207   if (any_of(Bndl, [](auto *V) { return !isa<Instruction>(V); })) {
208     LLVM_DEBUG(dbgs() << "Not vectorizing: Not Instructions:\n";
209                dumpBndl(Bndl););
210     return createLegalityResult<Pack>(ResultReason::NotInstructions);
211   }
212 
213   auto CollectDescrs = getHowToCollectValues(Bndl);
214   if (CollectDescrs.hasVectorInputs()) {
215     if (auto ValueShuffleOpt = CollectDescrs.getSingleInput()) {
216       auto [Vec, NeedsShuffle] = *ValueShuffleOpt;
217       if (!NeedsShuffle)
218         return createLegalityResult<DiamondReuse>(Vec);
219       llvm_unreachable("TODO: Unimplemented");
220     } else {
221       llvm_unreachable("TODO: Unimplemented");
222     }
223   }
224 
225   if (auto ReasonOpt = notVectorizableBasedOnOpcodesAndTypes(Bndl))
226     return createLegalityResult<Pack>(*ReasonOpt);
227 
228   if (!SkipScheduling) {
229     // TODO: Try to remove the IBndl vector.
230     SmallVector<Instruction *, 8> IBndl;
231     IBndl.reserve(Bndl.size());
232     for (auto *V : Bndl)
233       IBndl.push_back(cast<Instruction>(V));
234     if (!Sched.trySchedule(IBndl))
235       return createLegalityResult<Pack>(ResultReason::CantSchedule);
236   }
237 
238   return createLegalityResult<Widen>();
239 }
240 
241 void LegalityAnalysis::clear() {
242   Sched.clear();
243   IMaps.clear();
244 }
245 } // namespace llvm::sandboxir
246