1 //===- Legality.cpp -------------------------------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "llvm/Transforms/Vectorize/SandboxVectorizer/Legality.h" 10 #include "llvm/SandboxIR/Instruction.h" 11 #include "llvm/SandboxIR/Operator.h" 12 #include "llvm/SandboxIR/Utils.h" 13 #include "llvm/SandboxIR/Value.h" 14 #include "llvm/Support/Debug.h" 15 #include "llvm/Transforms/Vectorize/SandboxVectorizer/InstrMaps.h" 16 #include "llvm/Transforms/Vectorize/SandboxVectorizer/VecUtils.h" 17 18 namespace llvm::sandboxir { 19 20 #define DEBUG_TYPE "SBVec:Legality" 21 22 #ifndef NDEBUG 23 void ShuffleMask::dump() const { 24 print(dbgs()); 25 dbgs() << "\n"; 26 } 27 28 void LegalityResult::dump() const { 29 print(dbgs()); 30 dbgs() << "\n"; 31 } 32 #endif // NDEBUG 33 34 std::optional<ResultReason> 35 LegalityAnalysis::notVectorizableBasedOnOpcodesAndTypes( 36 ArrayRef<Value *> Bndl) { 37 auto *I0 = cast<Instruction>(Bndl[0]); 38 auto Opcode = I0->getOpcode(); 39 // If they have different opcodes, then we cannot form a vector (for now). 40 if (any_of(drop_begin(Bndl), [Opcode](Value *V) { 41 return cast<Instruction>(V)->getOpcode() != Opcode; 42 })) 43 return ResultReason::DiffOpcodes; 44 45 // If not the same scalar type, Pack. This will accept scalars and vectors as 46 // long as the element type is the same. 47 Type *ElmTy0 = VecUtils::getElementType(Utils::getExpectedType(I0)); 48 if (any_of(drop_begin(Bndl), [ElmTy0](Value *V) { 49 return VecUtils::getElementType(Utils::getExpectedType(V)) != ElmTy0; 50 })) 51 return ResultReason::DiffTypes; 52 53 // TODO: Allow vectorization of instrs with different flags as long as we 54 // change them to the least common one. 55 // For now pack if differnt FastMathFlags. 56 if (isa<FPMathOperator>(I0)) { 57 FastMathFlags FMF0 = cast<Instruction>(Bndl[0])->getFastMathFlags(); 58 if (any_of(drop_begin(Bndl), [FMF0](auto *V) { 59 return cast<Instruction>(V)->getFastMathFlags() != FMF0; 60 })) 61 return ResultReason::DiffMathFlags; 62 } 63 64 // TODO: Allow vectorization by using common flags. 65 // For now Pack if they don't have the same wrap flags. 66 bool CanHaveWrapFlags = 67 isa<OverflowingBinaryOperator>(I0) || isa<TruncInst>(I0); 68 if (CanHaveWrapFlags) { 69 bool NUW0 = I0->hasNoUnsignedWrap(); 70 bool NSW0 = I0->hasNoSignedWrap(); 71 if (any_of(drop_begin(Bndl), [NUW0, NSW0](auto *V) { 72 return cast<Instruction>(V)->hasNoUnsignedWrap() != NUW0 || 73 cast<Instruction>(V)->hasNoSignedWrap() != NSW0; 74 })) { 75 return ResultReason::DiffWrapFlags; 76 } 77 } 78 79 // Now we need to do further checks for specific opcodes. 80 switch (Opcode) { 81 case Instruction::Opcode::ZExt: 82 case Instruction::Opcode::SExt: 83 case Instruction::Opcode::FPToUI: 84 case Instruction::Opcode::FPToSI: 85 case Instruction::Opcode::FPExt: 86 case Instruction::Opcode::PtrToInt: 87 case Instruction::Opcode::IntToPtr: 88 case Instruction::Opcode::SIToFP: 89 case Instruction::Opcode::UIToFP: 90 case Instruction::Opcode::Trunc: 91 case Instruction::Opcode::FPTrunc: 92 case Instruction::Opcode::BitCast: { 93 // We have already checked that they are of the same opcode. 94 assert(all_of(Bndl, 95 [Opcode](Value *V) { 96 return cast<Instruction>(V)->getOpcode() == Opcode; 97 }) && 98 "Different opcodes, should have early returned!"); 99 // But for these opcodes we should also check the operand type. 100 Type *FromTy0 = Utils::getExpectedType(I0->getOperand(0)); 101 if (any_of(drop_begin(Bndl), [FromTy0](Value *V) { 102 return Utils::getExpectedType(cast<User>(V)->getOperand(0)) != 103 FromTy0; 104 })) 105 return ResultReason::DiffTypes; 106 return std::nullopt; 107 } 108 case Instruction::Opcode::FCmp: 109 case Instruction::Opcode::ICmp: { 110 // We need the same predicate.. 111 auto Pred0 = cast<CmpInst>(I0)->getPredicate(); 112 bool Same = all_of(Bndl, [Pred0](Value *V) { 113 return cast<CmpInst>(V)->getPredicate() == Pred0; 114 }); 115 if (Same) 116 return std::nullopt; 117 return ResultReason::DiffOpcodes; 118 } 119 case Instruction::Opcode::Select: 120 case Instruction::Opcode::FNeg: 121 case Instruction::Opcode::Add: 122 case Instruction::Opcode::FAdd: 123 case Instruction::Opcode::Sub: 124 case Instruction::Opcode::FSub: 125 case Instruction::Opcode::Mul: 126 case Instruction::Opcode::FMul: 127 case Instruction::Opcode::FRem: 128 case Instruction::Opcode::UDiv: 129 case Instruction::Opcode::SDiv: 130 case Instruction::Opcode::FDiv: 131 case Instruction::Opcode::URem: 132 case Instruction::Opcode::SRem: 133 case Instruction::Opcode::Shl: 134 case Instruction::Opcode::LShr: 135 case Instruction::Opcode::AShr: 136 case Instruction::Opcode::And: 137 case Instruction::Opcode::Or: 138 case Instruction::Opcode::Xor: 139 return std::nullopt; 140 case Instruction::Opcode::Load: 141 if (VecUtils::areConsecutive<LoadInst>(Bndl, SE, DL)) 142 return std::nullopt; 143 return ResultReason::NotConsecutive; 144 case Instruction::Opcode::Store: 145 if (VecUtils::areConsecutive<StoreInst>(Bndl, SE, DL)) 146 return std::nullopt; 147 return ResultReason::NotConsecutive; 148 case Instruction::Opcode::PHI: 149 return ResultReason::Unimplemented; 150 case Instruction::Opcode::Opaque: 151 return ResultReason::Unimplemented; 152 case Instruction::Opcode::Br: 153 case Instruction::Opcode::Ret: 154 case Instruction::Opcode::AddrSpaceCast: 155 case Instruction::Opcode::InsertElement: 156 case Instruction::Opcode::InsertValue: 157 case Instruction::Opcode::ExtractElement: 158 case Instruction::Opcode::ExtractValue: 159 case Instruction::Opcode::ShuffleVector: 160 case Instruction::Opcode::Call: 161 case Instruction::Opcode::GetElementPtr: 162 case Instruction::Opcode::Switch: 163 return ResultReason::Unimplemented; 164 case Instruction::Opcode::VAArg: 165 case Instruction::Opcode::Freeze: 166 case Instruction::Opcode::Fence: 167 case Instruction::Opcode::Invoke: 168 case Instruction::Opcode::CallBr: 169 case Instruction::Opcode::LandingPad: 170 case Instruction::Opcode::CatchPad: 171 case Instruction::Opcode::CleanupPad: 172 case Instruction::Opcode::CatchRet: 173 case Instruction::Opcode::CleanupRet: 174 case Instruction::Opcode::Resume: 175 case Instruction::Opcode::CatchSwitch: 176 case Instruction::Opcode::AtomicRMW: 177 case Instruction::Opcode::AtomicCmpXchg: 178 case Instruction::Opcode::Alloca: 179 case Instruction::Opcode::Unreachable: 180 return ResultReason::Infeasible; 181 } 182 183 return std::nullopt; 184 } 185 186 #ifndef NDEBUG 187 static void dumpBndl(ArrayRef<Value *> Bndl) { 188 for (auto *V : Bndl) 189 dbgs() << *V << "\n"; 190 } 191 #endif // NDEBUG 192 193 CollectDescr 194 LegalityAnalysis::getHowToCollectValues(ArrayRef<Value *> Bndl) const { 195 SmallVector<CollectDescr::ExtractElementDescr, 4> Vec; 196 Vec.reserve(Bndl.size()); 197 for (auto [Lane, V] : enumerate(Bndl)) { 198 if (auto *VecOp = IMaps.getVectorForOrig(V)) { 199 // If there is a vector containing `V`, then get the lane it came from. 200 std::optional<int> ExtractIdxOpt = IMaps.getOrigLane(VecOp, V); 201 Vec.emplace_back(VecOp, ExtractIdxOpt ? *ExtractIdxOpt : -1); 202 } else { 203 Vec.emplace_back(V); 204 } 205 } 206 return CollectDescr(std::move(Vec)); 207 } 208 209 const LegalityResult &LegalityAnalysis::canVectorize(ArrayRef<Value *> Bndl, 210 bool SkipScheduling) { 211 // If Bndl contains values other than instructions, we need to Pack. 212 if (any_of(Bndl, [](auto *V) { return !isa<Instruction>(V); })) { 213 LLVM_DEBUG(dbgs() << "Not vectorizing: Not Instructions:\n"; 214 dumpBndl(Bndl);); 215 return createLegalityResult<Pack>(ResultReason::NotInstructions); 216 } 217 // Pack if not in the same BB. 218 auto *BB = cast<Instruction>(Bndl[0])->getParent(); 219 if (any_of(drop_begin(Bndl), 220 [BB](auto *V) { return cast<Instruction>(V)->getParent() != BB; })) 221 return createLegalityResult<Pack>(ResultReason::DiffBBs); 222 223 auto CollectDescrs = getHowToCollectValues(Bndl); 224 if (CollectDescrs.hasVectorInputs()) { 225 if (auto ValueShuffleOpt = CollectDescrs.getSingleInput()) { 226 auto [Vec, Mask] = *ValueShuffleOpt; 227 if (Mask.isIdentity()) 228 return createLegalityResult<DiamondReuse>(Vec); 229 return createLegalityResult<DiamondReuseWithShuffle>(Vec, Mask); 230 } 231 return createLegalityResult<DiamondReuseMultiInput>( 232 std::move(CollectDescrs)); 233 } 234 235 if (auto ReasonOpt = notVectorizableBasedOnOpcodesAndTypes(Bndl)) 236 return createLegalityResult<Pack>(*ReasonOpt); 237 238 if (!SkipScheduling) { 239 // TODO: Try to remove the IBndl vector. 240 SmallVector<Instruction *, 8> IBndl; 241 IBndl.reserve(Bndl.size()); 242 for (auto *V : Bndl) 243 IBndl.push_back(cast<Instruction>(V)); 244 if (!Sched.trySchedule(IBndl)) 245 return createLegalityResult<Pack>(ResultReason::CantSchedule); 246 } 247 248 return createLegalityResult<Widen>(); 249 } 250 251 void LegalityAnalysis::clear() { 252 Sched.clear(); 253 IMaps.clear(); 254 } 255 } // namespace llvm::sandboxir 256