1 //===- Legality.cpp -------------------------------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "llvm/Transforms/Vectorize/SandboxVectorizer/Legality.h" 10 #include "llvm/SandboxIR/Instruction.h" 11 #include "llvm/SandboxIR/Operator.h" 12 #include "llvm/SandboxIR/Utils.h" 13 #include "llvm/SandboxIR/Value.h" 14 #include "llvm/Support/Debug.h" 15 #include "llvm/Transforms/Vectorize/SandboxVectorizer/InstrMaps.h" 16 #include "llvm/Transforms/Vectorize/SandboxVectorizer/VecUtils.h" 17 18 namespace llvm::sandboxir { 19 20 #define DEBUG_TYPE "SBVec:Legality" 21 22 #ifndef NDEBUG 23 void LegalityResult::dump() const { 24 print(dbgs()); 25 dbgs() << "\n"; 26 } 27 #endif // NDEBUG 28 29 std::optional<ResultReason> 30 LegalityAnalysis::notVectorizableBasedOnOpcodesAndTypes( 31 ArrayRef<Value *> Bndl) { 32 auto *I0 = cast<Instruction>(Bndl[0]); 33 auto Opcode = I0->getOpcode(); 34 // If they have different opcodes, then we cannot form a vector (for now). 35 if (any_of(drop_begin(Bndl), [Opcode](Value *V) { 36 return cast<Instruction>(V)->getOpcode() != Opcode; 37 })) 38 return ResultReason::DiffOpcodes; 39 40 // If not the same scalar type, Pack. This will accept scalars and vectors as 41 // long as the element type is the same. 42 Type *ElmTy0 = VecUtils::getElementType(Utils::getExpectedType(I0)); 43 if (any_of(drop_begin(Bndl), [ElmTy0](Value *V) { 44 return VecUtils::getElementType(Utils::getExpectedType(V)) != ElmTy0; 45 })) 46 return ResultReason::DiffTypes; 47 48 // TODO: Allow vectorization of instrs with different flags as long as we 49 // change them to the least common one. 50 // For now pack if differnt FastMathFlags. 51 if (isa<FPMathOperator>(I0)) { 52 FastMathFlags FMF0 = cast<Instruction>(Bndl[0])->getFastMathFlags(); 53 if (any_of(drop_begin(Bndl), [FMF0](auto *V) { 54 return cast<Instruction>(V)->getFastMathFlags() != FMF0; 55 })) 56 return ResultReason::DiffMathFlags; 57 } 58 59 // TODO: Allow vectorization by using common flags. 60 // For now Pack if they don't have the same wrap flags. 61 bool CanHaveWrapFlags = 62 isa<OverflowingBinaryOperator>(I0) || isa<TruncInst>(I0); 63 if (CanHaveWrapFlags) { 64 bool NUW0 = I0->hasNoUnsignedWrap(); 65 bool NSW0 = I0->hasNoSignedWrap(); 66 if (any_of(drop_begin(Bndl), [NUW0, NSW0](auto *V) { 67 return cast<Instruction>(V)->hasNoUnsignedWrap() != NUW0 || 68 cast<Instruction>(V)->hasNoSignedWrap() != NSW0; 69 })) { 70 return ResultReason::DiffWrapFlags; 71 } 72 } 73 74 // Now we need to do further checks for specific opcodes. 75 switch (Opcode) { 76 case Instruction::Opcode::ZExt: 77 case Instruction::Opcode::SExt: 78 case Instruction::Opcode::FPToUI: 79 case Instruction::Opcode::FPToSI: 80 case Instruction::Opcode::FPExt: 81 case Instruction::Opcode::PtrToInt: 82 case Instruction::Opcode::IntToPtr: 83 case Instruction::Opcode::SIToFP: 84 case Instruction::Opcode::UIToFP: 85 case Instruction::Opcode::Trunc: 86 case Instruction::Opcode::FPTrunc: 87 case Instruction::Opcode::BitCast: { 88 // We have already checked that they are of the same opcode. 89 assert(all_of(Bndl, 90 [Opcode](Value *V) { 91 return cast<Instruction>(V)->getOpcode() == Opcode; 92 }) && 93 "Different opcodes, should have early returned!"); 94 // But for these opcodes we should also check the operand type. 95 Type *FromTy0 = Utils::getExpectedType(I0->getOperand(0)); 96 if (any_of(drop_begin(Bndl), [FromTy0](Value *V) { 97 return Utils::getExpectedType(cast<User>(V)->getOperand(0)) != 98 FromTy0; 99 })) 100 return ResultReason::DiffTypes; 101 return std::nullopt; 102 } 103 case Instruction::Opcode::FCmp: 104 case Instruction::Opcode::ICmp: { 105 // We need the same predicate.. 106 auto Pred0 = cast<CmpInst>(I0)->getPredicate(); 107 bool Same = all_of(Bndl, [Pred0](Value *V) { 108 return cast<CmpInst>(V)->getPredicate() == Pred0; 109 }); 110 if (Same) 111 return std::nullopt; 112 return ResultReason::DiffOpcodes; 113 } 114 case Instruction::Opcode::Select: 115 case Instruction::Opcode::FNeg: 116 case Instruction::Opcode::Add: 117 case Instruction::Opcode::FAdd: 118 case Instruction::Opcode::Sub: 119 case Instruction::Opcode::FSub: 120 case Instruction::Opcode::Mul: 121 case Instruction::Opcode::FMul: 122 case Instruction::Opcode::FRem: 123 case Instruction::Opcode::UDiv: 124 case Instruction::Opcode::SDiv: 125 case Instruction::Opcode::FDiv: 126 case Instruction::Opcode::URem: 127 case Instruction::Opcode::SRem: 128 case Instruction::Opcode::Shl: 129 case Instruction::Opcode::LShr: 130 case Instruction::Opcode::AShr: 131 case Instruction::Opcode::And: 132 case Instruction::Opcode::Or: 133 case Instruction::Opcode::Xor: 134 return std::nullopt; 135 case Instruction::Opcode::Load: 136 if (VecUtils::areConsecutive<LoadInst>(Bndl, SE, DL)) 137 return std::nullopt; 138 return ResultReason::NotConsecutive; 139 case Instruction::Opcode::Store: 140 if (VecUtils::areConsecutive<StoreInst>(Bndl, SE, DL)) 141 return std::nullopt; 142 return ResultReason::NotConsecutive; 143 case Instruction::Opcode::PHI: 144 return ResultReason::Unimplemented; 145 case Instruction::Opcode::Opaque: 146 return ResultReason::Unimplemented; 147 case Instruction::Opcode::Br: 148 case Instruction::Opcode::Ret: 149 case Instruction::Opcode::AddrSpaceCast: 150 case Instruction::Opcode::InsertElement: 151 case Instruction::Opcode::InsertValue: 152 case Instruction::Opcode::ExtractElement: 153 case Instruction::Opcode::ExtractValue: 154 case Instruction::Opcode::ShuffleVector: 155 case Instruction::Opcode::Call: 156 case Instruction::Opcode::GetElementPtr: 157 case Instruction::Opcode::Switch: 158 return ResultReason::Unimplemented; 159 case Instruction::Opcode::VAArg: 160 case Instruction::Opcode::Freeze: 161 case Instruction::Opcode::Fence: 162 case Instruction::Opcode::Invoke: 163 case Instruction::Opcode::CallBr: 164 case Instruction::Opcode::LandingPad: 165 case Instruction::Opcode::CatchPad: 166 case Instruction::Opcode::CleanupPad: 167 case Instruction::Opcode::CatchRet: 168 case Instruction::Opcode::CleanupRet: 169 case Instruction::Opcode::Resume: 170 case Instruction::Opcode::CatchSwitch: 171 case Instruction::Opcode::AtomicRMW: 172 case Instruction::Opcode::AtomicCmpXchg: 173 case Instruction::Opcode::Alloca: 174 case Instruction::Opcode::Unreachable: 175 return ResultReason::Infeasible; 176 } 177 178 return std::nullopt; 179 } 180 181 #ifndef NDEBUG 182 static void dumpBndl(ArrayRef<Value *> Bndl) { 183 for (auto *V : Bndl) 184 dbgs() << *V << "\n"; 185 } 186 #endif // NDEBUG 187 188 CollectDescr 189 LegalityAnalysis::getHowToCollectValues(ArrayRef<Value *> Bndl) const { 190 SmallVector<CollectDescr::ExtractElementDescr, 4> Vec; 191 Vec.reserve(Bndl.size()); 192 for (auto [Lane, V] : enumerate(Bndl)) { 193 if (auto *VecOp = IMaps.getVectorForOrig(V)) { 194 // If there is a vector containing `V`, then get the lane it came from. 195 std::optional<int> ExtractIdxOpt = IMaps.getOrigLane(VecOp, V); 196 Vec.emplace_back(VecOp, ExtractIdxOpt ? *ExtractIdxOpt : -1); 197 } else { 198 Vec.emplace_back(V); 199 } 200 } 201 return CollectDescr(std::move(Vec)); 202 } 203 204 const LegalityResult &LegalityAnalysis::canVectorize(ArrayRef<Value *> Bndl, 205 bool SkipScheduling) { 206 // If Bndl contains values other than instructions, we need to Pack. 207 if (any_of(Bndl, [](auto *V) { return !isa<Instruction>(V); })) { 208 LLVM_DEBUG(dbgs() << "Not vectorizing: Not Instructions:\n"; 209 dumpBndl(Bndl);); 210 return createLegalityResult<Pack>(ResultReason::NotInstructions); 211 } 212 213 auto CollectDescrs = getHowToCollectValues(Bndl); 214 if (CollectDescrs.hasVectorInputs()) { 215 if (auto ValueShuffleOpt = CollectDescrs.getSingleInput()) { 216 auto [Vec, NeedsShuffle] = *ValueShuffleOpt; 217 if (!NeedsShuffle) 218 return createLegalityResult<DiamondReuse>(Vec); 219 llvm_unreachable("TODO: Unimplemented"); 220 } else { 221 llvm_unreachable("TODO: Unimplemented"); 222 } 223 } 224 225 if (auto ReasonOpt = notVectorizableBasedOnOpcodesAndTypes(Bndl)) 226 return createLegalityResult<Pack>(*ReasonOpt); 227 228 if (!SkipScheduling) { 229 // TODO: Try to remove the IBndl vector. 230 SmallVector<Instruction *, 8> IBndl; 231 IBndl.reserve(Bndl.size()); 232 for (auto *V : Bndl) 233 IBndl.push_back(cast<Instruction>(V)); 234 if (!Sched.trySchedule(IBndl)) 235 return createLegalityResult<Pack>(ResultReason::CantSchedule); 236 } 237 238 return createLegalityResult<Widen>(); 239 } 240 241 void LegalityAnalysis::clear() { 242 Sched.clear(); 243 IMaps.clear(); 244 } 245 } // namespace llvm::sandboxir 246