154c93aabSvporpo //===- Legality.cpp -------------------------------------------------------===// 254c93aabSvporpo // 354c93aabSvporpo // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 454c93aabSvporpo // See https://llvm.org/LICENSE.txt for license information. 554c93aabSvporpo // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 654c93aabSvporpo // 754c93aabSvporpo //===----------------------------------------------------------------------===// 854c93aabSvporpo 954c93aabSvporpo #include "llvm/Transforms/Vectorize/SandboxVectorizer/Legality.h" 101540f772SVasileios Porpodas #include "llvm/SandboxIR/Instruction.h" 11bf4b31adSvporpo #include "llvm/SandboxIR/Operator.h" 121540f772SVasileios Porpodas #include "llvm/SandboxIR/Utils.h" 1354c93aabSvporpo #include "llvm/SandboxIR/Value.h" 1454c93aabSvporpo #include "llvm/Support/Debug.h" 15e902c696Svporpo #include "llvm/Transforms/Vectorize/SandboxVectorizer/InstrMaps.h" 165ea69481Svporpo #include "llvm/Transforms/Vectorize/SandboxVectorizer/VecUtils.h" 1754c93aabSvporpo 1854c93aabSvporpo namespace llvm::sandboxir { 1954c93aabSvporpo 201540f772SVasileios Porpodas #define DEBUG_TYPE "SBVec:Legality" 211540f772SVasileios Porpodas 2254c93aabSvporpo #ifndef NDEBUG 2387e4b681Svporpo void ShuffleMask::dump() const { 2487e4b681Svporpo print(dbgs()); 2587e4b681Svporpo dbgs() << "\n"; 2687e4b681Svporpo } 2787e4b681Svporpo 2854c93aabSvporpo void LegalityResult::dump() const { 2954c93aabSvporpo print(dbgs()); 3054c93aabSvporpo dbgs() << "\n"; 3154c93aabSvporpo } 3254c93aabSvporpo #endif // NDEBUG 3354c93aabSvporpo 3454c93aabSvporpo std::optional<ResultReason> 3554c93aabSvporpo LegalityAnalysis::notVectorizableBasedOnOpcodesAndTypes( 3654c93aabSvporpo ArrayRef<Value *> Bndl) { 375ea69481Svporpo auto *I0 = cast<Instruction>(Bndl[0]); 385ea69481Svporpo auto Opcode = I0->getOpcode(); 395ea69481Svporpo // If they have different opcodes, then we cannot form a vector (for now). 405ea69481Svporpo if (any_of(drop_begin(Bndl), [Opcode](Value *V) { 415ea69481Svporpo return cast<Instruction>(V)->getOpcode() != Opcode; 425ea69481Svporpo })) 435ea69481Svporpo return ResultReason::DiffOpcodes; 445ea69481Svporpo 455ea69481Svporpo // If not the same scalar type, Pack. This will accept scalars and vectors as 465ea69481Svporpo // long as the element type is the same. 475ea69481Svporpo Type *ElmTy0 = VecUtils::getElementType(Utils::getExpectedType(I0)); 485ea69481Svporpo if (any_of(drop_begin(Bndl), [ElmTy0](Value *V) { 495ea69481Svporpo return VecUtils::getElementType(Utils::getExpectedType(V)) != ElmTy0; 505ea69481Svporpo })) 515ea69481Svporpo return ResultReason::DiffTypes; 525ea69481Svporpo 53bf4b31adSvporpo // TODO: Allow vectorization of instrs with different flags as long as we 54bf4b31adSvporpo // change them to the least common one. 55bf4b31adSvporpo // For now pack if differnt FastMathFlags. 56bf4b31adSvporpo if (isa<FPMathOperator>(I0)) { 57bf4b31adSvporpo FastMathFlags FMF0 = cast<Instruction>(Bndl[0])->getFastMathFlags(); 58bf4b31adSvporpo if (any_of(drop_begin(Bndl), [FMF0](auto *V) { 59bf4b31adSvporpo return cast<Instruction>(V)->getFastMathFlags() != FMF0; 60bf4b31adSvporpo })) 61bf4b31adSvporpo return ResultReason::DiffMathFlags; 62bf4b31adSvporpo } 63bf4b31adSvporpo 64ca998b07Svporpo // TODO: Allow vectorization by using common flags. 65ca998b07Svporpo // For now Pack if they don't have the same wrap flags. 66ca998b07Svporpo bool CanHaveWrapFlags = 67ca998b07Svporpo isa<OverflowingBinaryOperator>(I0) || isa<TruncInst>(I0); 68ca998b07Svporpo if (CanHaveWrapFlags) { 69ca998b07Svporpo bool NUW0 = I0->hasNoUnsignedWrap(); 70ca998b07Svporpo bool NSW0 = I0->hasNoSignedWrap(); 71ca998b07Svporpo if (any_of(drop_begin(Bndl), [NUW0, NSW0](auto *V) { 72ca998b07Svporpo return cast<Instruction>(V)->hasNoUnsignedWrap() != NUW0 || 73ca998b07Svporpo cast<Instruction>(V)->hasNoSignedWrap() != NSW0; 74ca998b07Svporpo })) { 75ca998b07Svporpo return ResultReason::DiffWrapFlags; 76ca998b07Svporpo } 77ca998b07Svporpo } 78ca998b07Svporpo 79083369fdSvporpo // Now we need to do further checks for specific opcodes. 80083369fdSvporpo switch (Opcode) { 81083369fdSvporpo case Instruction::Opcode::ZExt: 82083369fdSvporpo case Instruction::Opcode::SExt: 83083369fdSvporpo case Instruction::Opcode::FPToUI: 84083369fdSvporpo case Instruction::Opcode::FPToSI: 85083369fdSvporpo case Instruction::Opcode::FPExt: 86083369fdSvporpo case Instruction::Opcode::PtrToInt: 87083369fdSvporpo case Instruction::Opcode::IntToPtr: 88083369fdSvporpo case Instruction::Opcode::SIToFP: 89083369fdSvporpo case Instruction::Opcode::UIToFP: 90083369fdSvporpo case Instruction::Opcode::Trunc: 91083369fdSvporpo case Instruction::Opcode::FPTrunc: 92083369fdSvporpo case Instruction::Opcode::BitCast: { 93083369fdSvporpo // We have already checked that they are of the same opcode. 94083369fdSvporpo assert(all_of(Bndl, 95083369fdSvporpo [Opcode](Value *V) { 96083369fdSvporpo return cast<Instruction>(V)->getOpcode() == Opcode; 97083369fdSvporpo }) && 98083369fdSvporpo "Different opcodes, should have early returned!"); 99083369fdSvporpo // But for these opcodes we should also check the operand type. 100083369fdSvporpo Type *FromTy0 = Utils::getExpectedType(I0->getOperand(0)); 101083369fdSvporpo if (any_of(drop_begin(Bndl), [FromTy0](Value *V) { 102083369fdSvporpo return Utils::getExpectedType(cast<User>(V)->getOperand(0)) != 103083369fdSvporpo FromTy0; 104083369fdSvporpo })) 105083369fdSvporpo return ResultReason::DiffTypes; 106083369fdSvporpo return std::nullopt; 107083369fdSvporpo } 108083369fdSvporpo case Instruction::Opcode::FCmp: 109083369fdSvporpo case Instruction::Opcode::ICmp: { 110083369fdSvporpo // We need the same predicate.. 111083369fdSvporpo auto Pred0 = cast<CmpInst>(I0)->getPredicate(); 112083369fdSvporpo bool Same = all_of(Bndl, [Pred0](Value *V) { 113083369fdSvporpo return cast<CmpInst>(V)->getPredicate() == Pred0; 114083369fdSvporpo }); 115083369fdSvporpo if (Same) 116083369fdSvporpo return std::nullopt; 117083369fdSvporpo return ResultReason::DiffOpcodes; 118083369fdSvporpo } 119083369fdSvporpo case Instruction::Opcode::Select: 120083369fdSvporpo case Instruction::Opcode::FNeg: 121083369fdSvporpo case Instruction::Opcode::Add: 122083369fdSvporpo case Instruction::Opcode::FAdd: 123083369fdSvporpo case Instruction::Opcode::Sub: 124083369fdSvporpo case Instruction::Opcode::FSub: 125083369fdSvporpo case Instruction::Opcode::Mul: 126083369fdSvporpo case Instruction::Opcode::FMul: 127083369fdSvporpo case Instruction::Opcode::FRem: 128083369fdSvporpo case Instruction::Opcode::UDiv: 129083369fdSvporpo case Instruction::Opcode::SDiv: 130083369fdSvporpo case Instruction::Opcode::FDiv: 131083369fdSvporpo case Instruction::Opcode::URem: 132083369fdSvporpo case Instruction::Opcode::SRem: 133083369fdSvporpo case Instruction::Opcode::Shl: 134083369fdSvporpo case Instruction::Opcode::LShr: 135083369fdSvporpo case Instruction::Opcode::AShr: 136083369fdSvporpo case Instruction::Opcode::And: 137083369fdSvporpo case Instruction::Opcode::Or: 138083369fdSvporpo case Instruction::Opcode::Xor: 139083369fdSvporpo return std::nullopt; 140083369fdSvporpo case Instruction::Opcode::Load: 141083369fdSvporpo if (VecUtils::areConsecutive<LoadInst>(Bndl, SE, DL)) 142083369fdSvporpo return std::nullopt; 143083369fdSvporpo return ResultReason::NotConsecutive; 144083369fdSvporpo case Instruction::Opcode::Store: 145083369fdSvporpo if (VecUtils::areConsecutive<StoreInst>(Bndl, SE, DL)) 146083369fdSvporpo return std::nullopt; 147083369fdSvporpo return ResultReason::NotConsecutive; 148083369fdSvporpo case Instruction::Opcode::PHI: 149083369fdSvporpo return ResultReason::Unimplemented; 150083369fdSvporpo case Instruction::Opcode::Opaque: 151083369fdSvporpo return ResultReason::Unimplemented; 152083369fdSvporpo case Instruction::Opcode::Br: 153083369fdSvporpo case Instruction::Opcode::Ret: 154083369fdSvporpo case Instruction::Opcode::AddrSpaceCast: 155083369fdSvporpo case Instruction::Opcode::InsertElement: 156083369fdSvporpo case Instruction::Opcode::InsertValue: 157083369fdSvporpo case Instruction::Opcode::ExtractElement: 158083369fdSvporpo case Instruction::Opcode::ExtractValue: 159083369fdSvporpo case Instruction::Opcode::ShuffleVector: 160083369fdSvporpo case Instruction::Opcode::Call: 161083369fdSvporpo case Instruction::Opcode::GetElementPtr: 162083369fdSvporpo case Instruction::Opcode::Switch: 163083369fdSvporpo return ResultReason::Unimplemented; 164083369fdSvporpo case Instruction::Opcode::VAArg: 165083369fdSvporpo case Instruction::Opcode::Freeze: 166083369fdSvporpo case Instruction::Opcode::Fence: 167083369fdSvporpo case Instruction::Opcode::Invoke: 168083369fdSvporpo case Instruction::Opcode::CallBr: 169083369fdSvporpo case Instruction::Opcode::LandingPad: 170083369fdSvporpo case Instruction::Opcode::CatchPad: 171083369fdSvporpo case Instruction::Opcode::CleanupPad: 172083369fdSvporpo case Instruction::Opcode::CatchRet: 173083369fdSvporpo case Instruction::Opcode::CleanupRet: 174083369fdSvporpo case Instruction::Opcode::Resume: 175083369fdSvporpo case Instruction::Opcode::CatchSwitch: 176083369fdSvporpo case Instruction::Opcode::AtomicRMW: 177083369fdSvporpo case Instruction::Opcode::AtomicCmpXchg: 178083369fdSvporpo case Instruction::Opcode::Alloca: 179083369fdSvporpo case Instruction::Opcode::Unreachable: 180083369fdSvporpo return ResultReason::Infeasible; 181083369fdSvporpo } 1825ea69481Svporpo 18354c93aabSvporpo return std::nullopt; 18454c93aabSvporpo } 18554c93aabSvporpo 1861540f772SVasileios Porpodas #ifndef NDEBUG 1871540f772SVasileios Porpodas static void dumpBndl(ArrayRef<Value *> Bndl) { 1881540f772SVasileios Porpodas for (auto *V : Bndl) 1891540f772SVasileios Porpodas dbgs() << *V << "\n"; 1901540f772SVasileios Porpodas } 1911540f772SVasileios Porpodas #endif // NDEBUG 1921540f772SVasileios Porpodas 193e902c696Svporpo CollectDescr 194e902c696Svporpo LegalityAnalysis::getHowToCollectValues(ArrayRef<Value *> Bndl) const { 195e902c696Svporpo SmallVector<CollectDescr::ExtractElementDescr, 4> Vec; 196e902c696Svporpo Vec.reserve(Bndl.size()); 197e902c696Svporpo for (auto [Lane, V] : enumerate(Bndl)) { 198e902c696Svporpo if (auto *VecOp = IMaps.getVectorForOrig(V)) { 199e902c696Svporpo // If there is a vector containing `V`, then get the lane it came from. 200e902c696Svporpo std::optional<int> ExtractIdxOpt = IMaps.getOrigLane(VecOp, V); 201e902c696Svporpo Vec.emplace_back(VecOp, ExtractIdxOpt ? *ExtractIdxOpt : -1); 202e902c696Svporpo } else { 203e902c696Svporpo Vec.emplace_back(V); 204e902c696Svporpo } 205e902c696Svporpo } 206e902c696Svporpo return CollectDescr(std::move(Vec)); 207e902c696Svporpo } 208e902c696Svporpo 209ce0d0858Svporpo const LegalityResult &LegalityAnalysis::canVectorize(ArrayRef<Value *> Bndl, 210ce0d0858Svporpo bool SkipScheduling) { 2111540f772SVasileios Porpodas // If Bndl contains values other than instructions, we need to Pack. 2121540f772SVasileios Porpodas if (any_of(Bndl, [](auto *V) { return !isa<Instruction>(V); })) { 2131540f772SVasileios Porpodas LLVM_DEBUG(dbgs() << "Not vectorizing: Not Instructions:\n"; 2141540f772SVasileios Porpodas dumpBndl(Bndl);); 2151540f772SVasileios Porpodas return createLegalityResult<Pack>(ResultReason::NotInstructions); 2161540f772SVasileios Porpodas } 217*6409799bSvporpo // Pack if not in the same BB. 218*6409799bSvporpo auto *BB = cast<Instruction>(Bndl[0])->getParent(); 219*6409799bSvporpo if (any_of(drop_begin(Bndl), 220*6409799bSvporpo [BB](auto *V) { return cast<Instruction>(V)->getParent() != BB; })) 221*6409799bSvporpo return createLegalityResult<Pack>(ResultReason::DiffBBs); 2221540f772SVasileios Porpodas 223e902c696Svporpo auto CollectDescrs = getHowToCollectValues(Bndl); 224e902c696Svporpo if (CollectDescrs.hasVectorInputs()) { 225e902c696Svporpo if (auto ValueShuffleOpt = CollectDescrs.getSingleInput()) { 22687e4b681Svporpo auto [Vec, Mask] = *ValueShuffleOpt; 22787e4b681Svporpo if (Mask.isIdentity()) 228e902c696Svporpo return createLegalityResult<DiamondReuse>(Vec); 22987e4b681Svporpo return createLegalityResult<DiamondReuseWithShuffle>(Vec, Mask); 230e902c696Svporpo } 231fd087135Svporpo return createLegalityResult<DiamondReuseMultiInput>( 232fd087135Svporpo std::move(CollectDescrs)); 233e902c696Svporpo } 234e902c696Svporpo 23554c93aabSvporpo if (auto ReasonOpt = notVectorizableBasedOnOpcodesAndTypes(Bndl)) 23654c93aabSvporpo return createLegalityResult<Pack>(*ReasonOpt); 23754c93aabSvporpo 238ce0d0858Svporpo if (!SkipScheduling) { 239ce0d0858Svporpo // TODO: Try to remove the IBndl vector. 240ce0d0858Svporpo SmallVector<Instruction *, 8> IBndl; 241ce0d0858Svporpo IBndl.reserve(Bndl.size()); 242ce0d0858Svporpo for (auto *V : Bndl) 243ce0d0858Svporpo IBndl.push_back(cast<Instruction>(V)); 244ce0d0858Svporpo if (!Sched.trySchedule(IBndl)) 245ce0d0858Svporpo return createLegalityResult<Pack>(ResultReason::CantSchedule); 246ce0d0858Svporpo } 24754c93aabSvporpo 24854c93aabSvporpo return createLegalityResult<Widen>(); 24954c93aabSvporpo } 250e902c696Svporpo 251e902c696Svporpo void LegalityAnalysis::clear() { 252e902c696Svporpo Sched.clear(); 253e902c696Svporpo IMaps.clear(); 254e902c696Svporpo } 25554c93aabSvporpo } // namespace llvm::sandboxir 256