xref: /llvm-project/llvm/lib/Transforms/Vectorize/SandboxVectorizer/Legality.cpp (revision 6409799bdcd86be3ed72e8d172181294d3e5ad09)
154c93aabSvporpo //===- Legality.cpp -------------------------------------------------------===//
254c93aabSvporpo //
354c93aabSvporpo // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
454c93aabSvporpo // See https://llvm.org/LICENSE.txt for license information.
554c93aabSvporpo // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
654c93aabSvporpo //
754c93aabSvporpo //===----------------------------------------------------------------------===//
854c93aabSvporpo 
954c93aabSvporpo #include "llvm/Transforms/Vectorize/SandboxVectorizer/Legality.h"
101540f772SVasileios Porpodas #include "llvm/SandboxIR/Instruction.h"
11bf4b31adSvporpo #include "llvm/SandboxIR/Operator.h"
121540f772SVasileios Porpodas #include "llvm/SandboxIR/Utils.h"
1354c93aabSvporpo #include "llvm/SandboxIR/Value.h"
1454c93aabSvporpo #include "llvm/Support/Debug.h"
15e902c696Svporpo #include "llvm/Transforms/Vectorize/SandboxVectorizer/InstrMaps.h"
165ea69481Svporpo #include "llvm/Transforms/Vectorize/SandboxVectorizer/VecUtils.h"
1754c93aabSvporpo 
1854c93aabSvporpo namespace llvm::sandboxir {
1954c93aabSvporpo 
201540f772SVasileios Porpodas #define DEBUG_TYPE "SBVec:Legality"
211540f772SVasileios Porpodas 
2254c93aabSvporpo #ifndef NDEBUG
2387e4b681Svporpo void ShuffleMask::dump() const {
2487e4b681Svporpo   print(dbgs());
2587e4b681Svporpo   dbgs() << "\n";
2687e4b681Svporpo }
2787e4b681Svporpo 
2854c93aabSvporpo void LegalityResult::dump() const {
2954c93aabSvporpo   print(dbgs());
3054c93aabSvporpo   dbgs() << "\n";
3154c93aabSvporpo }
3254c93aabSvporpo #endif // NDEBUG
3354c93aabSvporpo 
3454c93aabSvporpo std::optional<ResultReason>
3554c93aabSvporpo LegalityAnalysis::notVectorizableBasedOnOpcodesAndTypes(
3654c93aabSvporpo     ArrayRef<Value *> Bndl) {
375ea69481Svporpo   auto *I0 = cast<Instruction>(Bndl[0]);
385ea69481Svporpo   auto Opcode = I0->getOpcode();
395ea69481Svporpo   // If they have different opcodes, then we cannot form a vector (for now).
405ea69481Svporpo   if (any_of(drop_begin(Bndl), [Opcode](Value *V) {
415ea69481Svporpo         return cast<Instruction>(V)->getOpcode() != Opcode;
425ea69481Svporpo       }))
435ea69481Svporpo     return ResultReason::DiffOpcodes;
445ea69481Svporpo 
455ea69481Svporpo   // If not the same scalar type, Pack. This will accept scalars and vectors as
465ea69481Svporpo   // long as the element type is the same.
475ea69481Svporpo   Type *ElmTy0 = VecUtils::getElementType(Utils::getExpectedType(I0));
485ea69481Svporpo   if (any_of(drop_begin(Bndl), [ElmTy0](Value *V) {
495ea69481Svporpo         return VecUtils::getElementType(Utils::getExpectedType(V)) != ElmTy0;
505ea69481Svporpo       }))
515ea69481Svporpo     return ResultReason::DiffTypes;
525ea69481Svporpo 
53bf4b31adSvporpo   // TODO: Allow vectorization of instrs with different flags as long as we
54bf4b31adSvporpo   // change them to the least common one.
55bf4b31adSvporpo   // For now pack if differnt FastMathFlags.
56bf4b31adSvporpo   if (isa<FPMathOperator>(I0)) {
57bf4b31adSvporpo     FastMathFlags FMF0 = cast<Instruction>(Bndl[0])->getFastMathFlags();
58bf4b31adSvporpo     if (any_of(drop_begin(Bndl), [FMF0](auto *V) {
59bf4b31adSvporpo           return cast<Instruction>(V)->getFastMathFlags() != FMF0;
60bf4b31adSvporpo         }))
61bf4b31adSvporpo       return ResultReason::DiffMathFlags;
62bf4b31adSvporpo   }
63bf4b31adSvporpo 
64ca998b07Svporpo   // TODO: Allow vectorization by using common flags.
65ca998b07Svporpo   // For now Pack if they don't have the same wrap flags.
66ca998b07Svporpo   bool CanHaveWrapFlags =
67ca998b07Svporpo       isa<OverflowingBinaryOperator>(I0) || isa<TruncInst>(I0);
68ca998b07Svporpo   if (CanHaveWrapFlags) {
69ca998b07Svporpo     bool NUW0 = I0->hasNoUnsignedWrap();
70ca998b07Svporpo     bool NSW0 = I0->hasNoSignedWrap();
71ca998b07Svporpo     if (any_of(drop_begin(Bndl), [NUW0, NSW0](auto *V) {
72ca998b07Svporpo           return cast<Instruction>(V)->hasNoUnsignedWrap() != NUW0 ||
73ca998b07Svporpo                  cast<Instruction>(V)->hasNoSignedWrap() != NSW0;
74ca998b07Svporpo         })) {
75ca998b07Svporpo       return ResultReason::DiffWrapFlags;
76ca998b07Svporpo     }
77ca998b07Svporpo   }
78ca998b07Svporpo 
79083369fdSvporpo   // Now we need to do further checks for specific opcodes.
80083369fdSvporpo   switch (Opcode) {
81083369fdSvporpo   case Instruction::Opcode::ZExt:
82083369fdSvporpo   case Instruction::Opcode::SExt:
83083369fdSvporpo   case Instruction::Opcode::FPToUI:
84083369fdSvporpo   case Instruction::Opcode::FPToSI:
85083369fdSvporpo   case Instruction::Opcode::FPExt:
86083369fdSvporpo   case Instruction::Opcode::PtrToInt:
87083369fdSvporpo   case Instruction::Opcode::IntToPtr:
88083369fdSvporpo   case Instruction::Opcode::SIToFP:
89083369fdSvporpo   case Instruction::Opcode::UIToFP:
90083369fdSvporpo   case Instruction::Opcode::Trunc:
91083369fdSvporpo   case Instruction::Opcode::FPTrunc:
92083369fdSvporpo   case Instruction::Opcode::BitCast: {
93083369fdSvporpo     // We have already checked that they are of the same opcode.
94083369fdSvporpo     assert(all_of(Bndl,
95083369fdSvporpo                   [Opcode](Value *V) {
96083369fdSvporpo                     return cast<Instruction>(V)->getOpcode() == Opcode;
97083369fdSvporpo                   }) &&
98083369fdSvporpo            "Different opcodes, should have early returned!");
99083369fdSvporpo     // But for these opcodes we should also check the operand type.
100083369fdSvporpo     Type *FromTy0 = Utils::getExpectedType(I0->getOperand(0));
101083369fdSvporpo     if (any_of(drop_begin(Bndl), [FromTy0](Value *V) {
102083369fdSvporpo           return Utils::getExpectedType(cast<User>(V)->getOperand(0)) !=
103083369fdSvporpo                  FromTy0;
104083369fdSvporpo         }))
105083369fdSvporpo       return ResultReason::DiffTypes;
106083369fdSvporpo     return std::nullopt;
107083369fdSvporpo   }
108083369fdSvporpo   case Instruction::Opcode::FCmp:
109083369fdSvporpo   case Instruction::Opcode::ICmp: {
110083369fdSvporpo     // We need the same predicate..
111083369fdSvporpo     auto Pred0 = cast<CmpInst>(I0)->getPredicate();
112083369fdSvporpo     bool Same = all_of(Bndl, [Pred0](Value *V) {
113083369fdSvporpo       return cast<CmpInst>(V)->getPredicate() == Pred0;
114083369fdSvporpo     });
115083369fdSvporpo     if (Same)
116083369fdSvporpo       return std::nullopt;
117083369fdSvporpo     return ResultReason::DiffOpcodes;
118083369fdSvporpo   }
119083369fdSvporpo   case Instruction::Opcode::Select:
120083369fdSvporpo   case Instruction::Opcode::FNeg:
121083369fdSvporpo   case Instruction::Opcode::Add:
122083369fdSvporpo   case Instruction::Opcode::FAdd:
123083369fdSvporpo   case Instruction::Opcode::Sub:
124083369fdSvporpo   case Instruction::Opcode::FSub:
125083369fdSvporpo   case Instruction::Opcode::Mul:
126083369fdSvporpo   case Instruction::Opcode::FMul:
127083369fdSvporpo   case Instruction::Opcode::FRem:
128083369fdSvporpo   case Instruction::Opcode::UDiv:
129083369fdSvporpo   case Instruction::Opcode::SDiv:
130083369fdSvporpo   case Instruction::Opcode::FDiv:
131083369fdSvporpo   case Instruction::Opcode::URem:
132083369fdSvporpo   case Instruction::Opcode::SRem:
133083369fdSvporpo   case Instruction::Opcode::Shl:
134083369fdSvporpo   case Instruction::Opcode::LShr:
135083369fdSvporpo   case Instruction::Opcode::AShr:
136083369fdSvporpo   case Instruction::Opcode::And:
137083369fdSvporpo   case Instruction::Opcode::Or:
138083369fdSvporpo   case Instruction::Opcode::Xor:
139083369fdSvporpo     return std::nullopt;
140083369fdSvporpo   case Instruction::Opcode::Load:
141083369fdSvporpo     if (VecUtils::areConsecutive<LoadInst>(Bndl, SE, DL))
142083369fdSvporpo       return std::nullopt;
143083369fdSvporpo     return ResultReason::NotConsecutive;
144083369fdSvporpo   case Instruction::Opcode::Store:
145083369fdSvporpo     if (VecUtils::areConsecutive<StoreInst>(Bndl, SE, DL))
146083369fdSvporpo       return std::nullopt;
147083369fdSvporpo     return ResultReason::NotConsecutive;
148083369fdSvporpo   case Instruction::Opcode::PHI:
149083369fdSvporpo     return ResultReason::Unimplemented;
150083369fdSvporpo   case Instruction::Opcode::Opaque:
151083369fdSvporpo     return ResultReason::Unimplemented;
152083369fdSvporpo   case Instruction::Opcode::Br:
153083369fdSvporpo   case Instruction::Opcode::Ret:
154083369fdSvporpo   case Instruction::Opcode::AddrSpaceCast:
155083369fdSvporpo   case Instruction::Opcode::InsertElement:
156083369fdSvporpo   case Instruction::Opcode::InsertValue:
157083369fdSvporpo   case Instruction::Opcode::ExtractElement:
158083369fdSvporpo   case Instruction::Opcode::ExtractValue:
159083369fdSvporpo   case Instruction::Opcode::ShuffleVector:
160083369fdSvporpo   case Instruction::Opcode::Call:
161083369fdSvporpo   case Instruction::Opcode::GetElementPtr:
162083369fdSvporpo   case Instruction::Opcode::Switch:
163083369fdSvporpo     return ResultReason::Unimplemented;
164083369fdSvporpo   case Instruction::Opcode::VAArg:
165083369fdSvporpo   case Instruction::Opcode::Freeze:
166083369fdSvporpo   case Instruction::Opcode::Fence:
167083369fdSvporpo   case Instruction::Opcode::Invoke:
168083369fdSvporpo   case Instruction::Opcode::CallBr:
169083369fdSvporpo   case Instruction::Opcode::LandingPad:
170083369fdSvporpo   case Instruction::Opcode::CatchPad:
171083369fdSvporpo   case Instruction::Opcode::CleanupPad:
172083369fdSvporpo   case Instruction::Opcode::CatchRet:
173083369fdSvporpo   case Instruction::Opcode::CleanupRet:
174083369fdSvporpo   case Instruction::Opcode::Resume:
175083369fdSvporpo   case Instruction::Opcode::CatchSwitch:
176083369fdSvporpo   case Instruction::Opcode::AtomicRMW:
177083369fdSvporpo   case Instruction::Opcode::AtomicCmpXchg:
178083369fdSvporpo   case Instruction::Opcode::Alloca:
179083369fdSvporpo   case Instruction::Opcode::Unreachable:
180083369fdSvporpo     return ResultReason::Infeasible;
181083369fdSvporpo   }
1825ea69481Svporpo 
18354c93aabSvporpo   return std::nullopt;
18454c93aabSvporpo }
18554c93aabSvporpo 
1861540f772SVasileios Porpodas #ifndef NDEBUG
1871540f772SVasileios Porpodas static void dumpBndl(ArrayRef<Value *> Bndl) {
1881540f772SVasileios Porpodas   for (auto *V : Bndl)
1891540f772SVasileios Porpodas     dbgs() << *V << "\n";
1901540f772SVasileios Porpodas }
1911540f772SVasileios Porpodas #endif // NDEBUG
1921540f772SVasileios Porpodas 
193e902c696Svporpo CollectDescr
194e902c696Svporpo LegalityAnalysis::getHowToCollectValues(ArrayRef<Value *> Bndl) const {
195e902c696Svporpo   SmallVector<CollectDescr::ExtractElementDescr, 4> Vec;
196e902c696Svporpo   Vec.reserve(Bndl.size());
197e902c696Svporpo   for (auto [Lane, V] : enumerate(Bndl)) {
198e902c696Svporpo     if (auto *VecOp = IMaps.getVectorForOrig(V)) {
199e902c696Svporpo       // If there is a vector containing `V`, then get the lane it came from.
200e902c696Svporpo       std::optional<int> ExtractIdxOpt = IMaps.getOrigLane(VecOp, V);
201e902c696Svporpo       Vec.emplace_back(VecOp, ExtractIdxOpt ? *ExtractIdxOpt : -1);
202e902c696Svporpo     } else {
203e902c696Svporpo       Vec.emplace_back(V);
204e902c696Svporpo     }
205e902c696Svporpo   }
206e902c696Svporpo   return CollectDescr(std::move(Vec));
207e902c696Svporpo }
208e902c696Svporpo 
209ce0d0858Svporpo const LegalityResult &LegalityAnalysis::canVectorize(ArrayRef<Value *> Bndl,
210ce0d0858Svporpo                                                      bool SkipScheduling) {
2111540f772SVasileios Porpodas   // If Bndl contains values other than instructions, we need to Pack.
2121540f772SVasileios Porpodas   if (any_of(Bndl, [](auto *V) { return !isa<Instruction>(V); })) {
2131540f772SVasileios Porpodas     LLVM_DEBUG(dbgs() << "Not vectorizing: Not Instructions:\n";
2141540f772SVasileios Porpodas                dumpBndl(Bndl););
2151540f772SVasileios Porpodas     return createLegalityResult<Pack>(ResultReason::NotInstructions);
2161540f772SVasileios Porpodas   }
217*6409799bSvporpo   // Pack if not in the same BB.
218*6409799bSvporpo   auto *BB = cast<Instruction>(Bndl[0])->getParent();
219*6409799bSvporpo   if (any_of(drop_begin(Bndl),
220*6409799bSvporpo              [BB](auto *V) { return cast<Instruction>(V)->getParent() != BB; }))
221*6409799bSvporpo     return createLegalityResult<Pack>(ResultReason::DiffBBs);
2221540f772SVasileios Porpodas 
223e902c696Svporpo   auto CollectDescrs = getHowToCollectValues(Bndl);
224e902c696Svporpo   if (CollectDescrs.hasVectorInputs()) {
225e902c696Svporpo     if (auto ValueShuffleOpt = CollectDescrs.getSingleInput()) {
22687e4b681Svporpo       auto [Vec, Mask] = *ValueShuffleOpt;
22787e4b681Svporpo       if (Mask.isIdentity())
228e902c696Svporpo         return createLegalityResult<DiamondReuse>(Vec);
22987e4b681Svporpo       return createLegalityResult<DiamondReuseWithShuffle>(Vec, Mask);
230e902c696Svporpo     }
231fd087135Svporpo     return createLegalityResult<DiamondReuseMultiInput>(
232fd087135Svporpo         std::move(CollectDescrs));
233e902c696Svporpo   }
234e902c696Svporpo 
23554c93aabSvporpo   if (auto ReasonOpt = notVectorizableBasedOnOpcodesAndTypes(Bndl))
23654c93aabSvporpo     return createLegalityResult<Pack>(*ReasonOpt);
23754c93aabSvporpo 
238ce0d0858Svporpo   if (!SkipScheduling) {
239ce0d0858Svporpo     // TODO: Try to remove the IBndl vector.
240ce0d0858Svporpo     SmallVector<Instruction *, 8> IBndl;
241ce0d0858Svporpo     IBndl.reserve(Bndl.size());
242ce0d0858Svporpo     for (auto *V : Bndl)
243ce0d0858Svporpo       IBndl.push_back(cast<Instruction>(V));
244ce0d0858Svporpo     if (!Sched.trySchedule(IBndl))
245ce0d0858Svporpo       return createLegalityResult<Pack>(ResultReason::CantSchedule);
246ce0d0858Svporpo   }
24754c93aabSvporpo 
24854c93aabSvporpo   return createLegalityResult<Widen>();
24954c93aabSvporpo }
250e902c696Svporpo 
251e902c696Svporpo void LegalityAnalysis::clear() {
252e902c696Svporpo   Sched.clear();
253e902c696Svporpo   IMaps.clear();
254e902c696Svporpo }
25554c93aabSvporpo } // namespace llvm::sandboxir
256