1 //===- SPIRVLegalizerInfo.cpp --- SPIR-V Legalization Rules ------*- C++ -*-==// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements the targeting of the Machinelegalizer class for SPIR-V. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "SPIRVLegalizerInfo.h" 14 #include "SPIRV.h" 15 #include "SPIRVGlobalRegistry.h" 16 #include "SPIRVSubtarget.h" 17 #include "llvm/CodeGen/GlobalISel/LegalizerHelper.h" 18 #include "llvm/CodeGen/GlobalISel/MachineIRBuilder.h" 19 #include "llvm/CodeGen/MachineInstr.h" 20 #include "llvm/CodeGen/MachineRegisterInfo.h" 21 #include "llvm/CodeGen/TargetOpcodes.h" 22 23 using namespace llvm; 24 using namespace llvm::LegalizeActions; 25 using namespace llvm::LegalityPredicates; 26 27 // clang-format off 28 static const std::set<unsigned> TypeFoldingSupportingOpcs = { 29 TargetOpcode::G_ADD, 30 TargetOpcode::G_FADD, 31 TargetOpcode::G_STRICT_FADD, 32 TargetOpcode::G_SUB, 33 TargetOpcode::G_FSUB, 34 TargetOpcode::G_STRICT_FSUB, 35 TargetOpcode::G_MUL, 36 TargetOpcode::G_FMUL, 37 TargetOpcode::G_STRICT_FMUL, 38 TargetOpcode::G_SDIV, 39 TargetOpcode::G_UDIV, 40 TargetOpcode::G_FDIV, 41 TargetOpcode::G_STRICT_FDIV, 42 TargetOpcode::G_SREM, 43 TargetOpcode::G_UREM, 44 TargetOpcode::G_FREM, 45 TargetOpcode::G_STRICT_FREM, 46 TargetOpcode::G_FNEG, 47 TargetOpcode::G_CONSTANT, 48 TargetOpcode::G_FCONSTANT, 49 TargetOpcode::G_AND, 50 TargetOpcode::G_OR, 51 TargetOpcode::G_XOR, 52 TargetOpcode::G_SHL, 53 TargetOpcode::G_ASHR, 54 TargetOpcode::G_LSHR, 55 TargetOpcode::G_SELECT, 56 TargetOpcode::G_EXTRACT_VECTOR_ELT, 57 }; 58 // clang-format on 59 60 bool isTypeFoldingSupported(unsigned Opcode) { 61 return TypeFoldingSupportingOpcs.count(Opcode) > 0; 62 } 63 64 LegalityPredicate typeOfExtendedScalars(unsigned TypeIdx, bool IsExtendedInts) { 65 return [IsExtendedInts, TypeIdx](const LegalityQuery &Query) { 66 const LLT Ty = Query.Types[TypeIdx]; 67 return IsExtendedInts && Ty.isValid() && Ty.isScalar(); 68 }; 69 } 70 71 SPIRVLegalizerInfo::SPIRVLegalizerInfo(const SPIRVSubtarget &ST) { 72 using namespace TargetOpcode; 73 74 this->ST = &ST; 75 GR = ST.getSPIRVGlobalRegistry(); 76 77 const LLT s1 = LLT::scalar(1); 78 const LLT s8 = LLT::scalar(8); 79 const LLT s16 = LLT::scalar(16); 80 const LLT s32 = LLT::scalar(32); 81 const LLT s64 = LLT::scalar(64); 82 83 const LLT v16s64 = LLT::fixed_vector(16, 64); 84 const LLT v16s32 = LLT::fixed_vector(16, 32); 85 const LLT v16s16 = LLT::fixed_vector(16, 16); 86 const LLT v16s8 = LLT::fixed_vector(16, 8); 87 const LLT v16s1 = LLT::fixed_vector(16, 1); 88 89 const LLT v8s64 = LLT::fixed_vector(8, 64); 90 const LLT v8s32 = LLT::fixed_vector(8, 32); 91 const LLT v8s16 = LLT::fixed_vector(8, 16); 92 const LLT v8s8 = LLT::fixed_vector(8, 8); 93 const LLT v8s1 = LLT::fixed_vector(8, 1); 94 95 const LLT v4s64 = LLT::fixed_vector(4, 64); 96 const LLT v4s32 = LLT::fixed_vector(4, 32); 97 const LLT v4s16 = LLT::fixed_vector(4, 16); 98 const LLT v4s8 = LLT::fixed_vector(4, 8); 99 const LLT v4s1 = LLT::fixed_vector(4, 1); 100 101 const LLT v3s64 = LLT::fixed_vector(3, 64); 102 const LLT v3s32 = LLT::fixed_vector(3, 32); 103 const LLT v3s16 = LLT::fixed_vector(3, 16); 104 const LLT v3s8 = LLT::fixed_vector(3, 8); 105 const LLT v3s1 = LLT::fixed_vector(3, 1); 106 107 const LLT v2s64 = LLT::fixed_vector(2, 64); 108 const LLT v2s32 = LLT::fixed_vector(2, 32); 109 const LLT v2s16 = LLT::fixed_vector(2, 16); 110 const LLT v2s8 = LLT::fixed_vector(2, 8); 111 const LLT v2s1 = LLT::fixed_vector(2, 1); 112 113 const unsigned PSize = ST.getPointerSize(); 114 const LLT p0 = LLT::pointer(0, PSize); // Function 115 const LLT p1 = LLT::pointer(1, PSize); // CrossWorkgroup 116 const LLT p2 = LLT::pointer(2, PSize); // UniformConstant 117 const LLT p3 = LLT::pointer(3, PSize); // Workgroup 118 const LLT p4 = LLT::pointer(4, PSize); // Generic 119 const LLT p5 = 120 LLT::pointer(5, PSize); // Input, SPV_INTEL_usm_storage_classes (Device) 121 const LLT p6 = LLT::pointer(6, PSize); // SPV_INTEL_usm_storage_classes (Host) 122 const LLT p7 = LLT::pointer(7, PSize); // Input 123 const LLT p8 = LLT::pointer(8, PSize); // Output 124 const LLT p10 = LLT::pointer(10, PSize); // Private 125 126 // TODO: remove copy-pasting here by using concatenation in some way. 127 auto allPtrsScalarsAndVectors = { 128 p0, p1, p2, p3, p4, p5, p6, p7, p8, p10, 129 s1, s8, s16, s32, s64, v2s1, v2s8, v2s16, v2s32, v2s64, 130 v3s1, v3s8, v3s16, v3s32, v3s64, v4s1, v4s8, v4s16, v4s32, v4s64, 131 v8s1, v8s8, v8s16, v8s32, v8s64, v16s1, v16s8, v16s16, v16s32, v16s64}; 132 133 auto allVectors = {v2s1, v2s8, v2s16, v2s32, v2s64, v3s1, v3s8, 134 v3s16, v3s32, v3s64, v4s1, v4s8, v4s16, v4s32, 135 v4s64, v8s1, v8s8, v8s16, v8s32, v8s64, v16s1, 136 v16s8, v16s16, v16s32, v16s64}; 137 138 auto allScalarsAndVectors = { 139 s1, s8, s16, s32, s64, v2s1, v2s8, v2s16, v2s32, v2s64, 140 v3s1, v3s8, v3s16, v3s32, v3s64, v4s1, v4s8, v4s16, v4s32, v4s64, 141 v8s1, v8s8, v8s16, v8s32, v8s64, v16s1, v16s8, v16s16, v16s32, v16s64}; 142 143 auto allIntScalarsAndVectors = {s8, s16, s32, s64, v2s8, v2s16, 144 v2s32, v2s64, v3s8, v3s16, v3s32, v3s64, 145 v4s8, v4s16, v4s32, v4s64, v8s8, v8s16, 146 v8s32, v8s64, v16s8, v16s16, v16s32, v16s64}; 147 148 auto allBoolScalarsAndVectors = {s1, v2s1, v3s1, v4s1, v8s1, v16s1}; 149 150 auto allIntScalars = {s8, s16, s32, s64}; 151 152 auto allFloatScalars = {s16, s32, s64}; 153 154 auto allFloatScalarsAndVectors = { 155 s16, s32, s64, v2s16, v2s32, v2s64, v3s16, v3s32, v3s64, 156 v4s16, v4s32, v4s64, v8s16, v8s32, v8s64, v16s16, v16s32, v16s64}; 157 158 auto allFloatAndIntScalarsAndPtrs = {s8, s16, s32, s64, p0, p1, p2, 159 p3, p4, p5, p6, p7, p8, p10}; 160 161 auto allPtrs = {p0, p1, p2, p3, p4, p5, p6, p7, p8, p10}; 162 163 bool IsExtendedInts = 164 ST.canUseExtension( 165 SPIRV::Extension::SPV_INTEL_arbitrary_precision_integers) || 166 ST.canUseExtension(SPIRV::Extension::SPV_KHR_bit_instructions); 167 auto extendedScalarsAndVectors = 168 [IsExtendedInts](const LegalityQuery &Query) { 169 const LLT Ty = Query.Types[0]; 170 return IsExtendedInts && Ty.isValid() && !Ty.isPointerOrPointerVector(); 171 }; 172 auto extendedScalarsAndVectorsProduct = [IsExtendedInts]( 173 const LegalityQuery &Query) { 174 const LLT Ty1 = Query.Types[0], Ty2 = Query.Types[1]; 175 return IsExtendedInts && Ty1.isValid() && Ty2.isValid() && 176 !Ty1.isPointerOrPointerVector() && !Ty2.isPointerOrPointerVector(); 177 }; 178 auto extendedPtrsScalarsAndVectors = 179 [IsExtendedInts](const LegalityQuery &Query) { 180 const LLT Ty = Query.Types[0]; 181 return IsExtendedInts && Ty.isValid(); 182 }; 183 184 for (auto Opc : TypeFoldingSupportingOpcs) 185 getActionDefinitionsBuilder(Opc).custom(); 186 187 getActionDefinitionsBuilder(G_GLOBAL_VALUE).alwaysLegal(); 188 189 // TODO: add proper rules for vectors legalization. 190 getActionDefinitionsBuilder( 191 {G_BUILD_VECTOR, G_SHUFFLE_VECTOR, G_SPLAT_VECTOR}) 192 .alwaysLegal(); 193 194 // Vector Reduction Operations 195 getActionDefinitionsBuilder( 196 {G_VECREDUCE_SMIN, G_VECREDUCE_SMAX, G_VECREDUCE_UMIN, G_VECREDUCE_UMAX, 197 G_VECREDUCE_ADD, G_VECREDUCE_MUL, G_VECREDUCE_FMUL, G_VECREDUCE_FMIN, 198 G_VECREDUCE_FMAX, G_VECREDUCE_FMINIMUM, G_VECREDUCE_FMAXIMUM, 199 G_VECREDUCE_OR, G_VECREDUCE_AND, G_VECREDUCE_XOR}) 200 .legalFor(allVectors) 201 .scalarize(1) 202 .lower(); 203 204 getActionDefinitionsBuilder({G_VECREDUCE_SEQ_FADD, G_VECREDUCE_SEQ_FMUL}) 205 .scalarize(2) 206 .lower(); 207 208 // Merge/Unmerge 209 // TODO: add proper legalization rules. 210 getActionDefinitionsBuilder(G_UNMERGE_VALUES).alwaysLegal(); 211 212 getActionDefinitionsBuilder({G_MEMCPY, G_MEMMOVE}) 213 .legalIf(all(typeInSet(0, allPtrs), typeInSet(1, allPtrs))); 214 215 getActionDefinitionsBuilder(G_MEMSET).legalIf( 216 all(typeInSet(0, allPtrs), typeInSet(1, allIntScalars))); 217 218 getActionDefinitionsBuilder(G_ADDRSPACE_CAST) 219 .legalForCartesianProduct(allPtrs, allPtrs); 220 221 getActionDefinitionsBuilder({G_LOAD, G_STORE}).legalIf(typeInSet(1, allPtrs)); 222 223 getActionDefinitionsBuilder({G_SMIN, G_SMAX, G_UMIN, G_UMAX, G_ABS, 224 G_BITREVERSE, G_SADDSAT, G_UADDSAT, G_SSUBSAT, 225 G_USUBSAT, G_SCMP, G_UCMP}) 226 .legalFor(allIntScalarsAndVectors) 227 .legalIf(extendedScalarsAndVectors); 228 229 getActionDefinitionsBuilder({G_FMA, G_STRICT_FMA}) 230 .legalFor(allFloatScalarsAndVectors); 231 232 getActionDefinitionsBuilder(G_STRICT_FLDEXP) 233 .legalForCartesianProduct(allFloatScalarsAndVectors, allIntScalars); 234 235 getActionDefinitionsBuilder({G_FPTOSI, G_FPTOUI}) 236 .legalForCartesianProduct(allIntScalarsAndVectors, 237 allFloatScalarsAndVectors); 238 239 getActionDefinitionsBuilder({G_SITOFP, G_UITOFP}) 240 .legalForCartesianProduct(allFloatScalarsAndVectors, 241 allScalarsAndVectors); 242 243 getActionDefinitionsBuilder(G_CTPOP) 244 .legalForCartesianProduct(allIntScalarsAndVectors) 245 .legalIf(extendedScalarsAndVectorsProduct); 246 247 // Extensions. 248 getActionDefinitionsBuilder({G_TRUNC, G_ZEXT, G_SEXT, G_ANYEXT}) 249 .legalForCartesianProduct(allScalarsAndVectors) 250 .legalIf(extendedScalarsAndVectorsProduct); 251 252 getActionDefinitionsBuilder(G_PHI) 253 .legalFor(allPtrsScalarsAndVectors) 254 .legalIf(extendedPtrsScalarsAndVectors); 255 256 getActionDefinitionsBuilder(G_BITCAST).legalIf( 257 all(typeInSet(0, allPtrsScalarsAndVectors), 258 typeInSet(1, allPtrsScalarsAndVectors))); 259 260 getActionDefinitionsBuilder({G_IMPLICIT_DEF, G_FREEZE}).alwaysLegal(); 261 262 getActionDefinitionsBuilder({G_STACKSAVE, G_STACKRESTORE}).alwaysLegal(); 263 264 getActionDefinitionsBuilder(G_INTTOPTR) 265 .legalForCartesianProduct(allPtrs, allIntScalars) 266 .legalIf( 267 all(typeInSet(0, allPtrs), typeOfExtendedScalars(1, IsExtendedInts))); 268 getActionDefinitionsBuilder(G_PTRTOINT) 269 .legalForCartesianProduct(allIntScalars, allPtrs) 270 .legalIf( 271 all(typeOfExtendedScalars(0, IsExtendedInts), typeInSet(1, allPtrs))); 272 getActionDefinitionsBuilder(G_PTR_ADD) 273 .legalForCartesianProduct(allPtrs, allIntScalars) 274 .legalIf( 275 all(typeInSet(0, allPtrs), typeOfExtendedScalars(1, IsExtendedInts))); 276 277 // ST.canDirectlyComparePointers() for pointer args is supported in 278 // legalizeCustom(). 279 getActionDefinitionsBuilder(G_ICMP).customIf( 280 all(typeInSet(0, allBoolScalarsAndVectors), 281 typeInSet(1, allPtrsScalarsAndVectors))); 282 283 getActionDefinitionsBuilder(G_FCMP).legalIf( 284 all(typeInSet(0, allBoolScalarsAndVectors), 285 typeInSet(1, allFloatScalarsAndVectors))); 286 287 getActionDefinitionsBuilder({G_ATOMICRMW_OR, G_ATOMICRMW_ADD, G_ATOMICRMW_AND, 288 G_ATOMICRMW_MAX, G_ATOMICRMW_MIN, 289 G_ATOMICRMW_SUB, G_ATOMICRMW_XOR, 290 G_ATOMICRMW_UMAX, G_ATOMICRMW_UMIN}) 291 .legalForCartesianProduct(allIntScalars, allPtrs); 292 293 getActionDefinitionsBuilder( 294 {G_ATOMICRMW_FADD, G_ATOMICRMW_FSUB, G_ATOMICRMW_FMIN, G_ATOMICRMW_FMAX}) 295 .legalForCartesianProduct(allFloatScalars, allPtrs); 296 297 getActionDefinitionsBuilder(G_ATOMICRMW_XCHG) 298 .legalForCartesianProduct(allFloatAndIntScalarsAndPtrs, allPtrs); 299 300 getActionDefinitionsBuilder(G_ATOMIC_CMPXCHG_WITH_SUCCESS).lower(); 301 // TODO: add proper legalization rules. 302 getActionDefinitionsBuilder(G_ATOMIC_CMPXCHG).alwaysLegal(); 303 304 getActionDefinitionsBuilder( 305 {G_UADDO, G_SADDO, G_USUBO, G_SSUBO, G_UMULO, G_SMULO}) 306 .alwaysLegal(); 307 308 // FP conversions. 309 getActionDefinitionsBuilder({G_FPTRUNC, G_FPEXT}) 310 .legalForCartesianProduct(allFloatScalarsAndVectors); 311 312 // Pointer-handling. 313 getActionDefinitionsBuilder(G_FRAME_INDEX).legalFor({p0}); 314 315 // Control-flow. In some cases (e.g. constants) s1 may be promoted to s32. 316 getActionDefinitionsBuilder(G_BRCOND).legalFor({s1, s32}); 317 318 // TODO: Review the target OpenCL and GLSL Extended Instruction Set specs to 319 // tighten these requirements. Many of these math functions are only legal on 320 // specific bitwidths, so they are not selectable for 321 // allFloatScalarsAndVectors. 322 getActionDefinitionsBuilder({G_STRICT_FSQRT, 323 G_FPOW, 324 G_FEXP, 325 G_FEXP2, 326 G_FLOG, 327 G_FLOG2, 328 G_FLOG10, 329 G_FABS, 330 G_FMINNUM, 331 G_FMAXNUM, 332 G_FCEIL, 333 G_FCOS, 334 G_FSIN, 335 G_FTAN, 336 G_FACOS, 337 G_FASIN, 338 G_FATAN, 339 G_FATAN2, 340 G_FCOSH, 341 G_FSINH, 342 G_FTANH, 343 G_FSQRT, 344 G_FFLOOR, 345 G_FRINT, 346 G_FNEARBYINT, 347 G_INTRINSIC_ROUND, 348 G_INTRINSIC_TRUNC, 349 G_FMINIMUM, 350 G_FMAXIMUM, 351 G_INTRINSIC_ROUNDEVEN}) 352 .legalFor(allFloatScalarsAndVectors); 353 354 getActionDefinitionsBuilder(G_FCOPYSIGN) 355 .legalForCartesianProduct(allFloatScalarsAndVectors, 356 allFloatScalarsAndVectors); 357 358 getActionDefinitionsBuilder(G_FPOWI).legalForCartesianProduct( 359 allFloatScalarsAndVectors, allIntScalarsAndVectors); 360 361 if (ST.canUseExtInstSet(SPIRV::InstructionSet::OpenCL_std)) { 362 getActionDefinitionsBuilder( 363 {G_CTTZ, G_CTTZ_ZERO_UNDEF, G_CTLZ, G_CTLZ_ZERO_UNDEF}) 364 .legalForCartesianProduct(allIntScalarsAndVectors, 365 allIntScalarsAndVectors); 366 367 // Struct return types become a single scalar, so cannot easily legalize. 368 getActionDefinitionsBuilder({G_SMULH, G_UMULH}).alwaysLegal(); 369 } 370 371 getLegacyLegalizerInfo().computeTables(); 372 verify(*ST.getInstrInfo()); 373 } 374 375 static Register convertPtrToInt(Register Reg, LLT ConvTy, SPIRVType *SpvType, 376 LegalizerHelper &Helper, 377 MachineRegisterInfo &MRI, 378 SPIRVGlobalRegistry *GR) { 379 Register ConvReg = MRI.createGenericVirtualRegister(ConvTy); 380 MRI.setRegClass(ConvReg, GR->getRegClass(SpvType)); 381 GR->assignSPIRVTypeToVReg(SpvType, ConvReg, Helper.MIRBuilder.getMF()); 382 Helper.MIRBuilder.buildInstr(TargetOpcode::G_PTRTOINT) 383 .addDef(ConvReg) 384 .addUse(Reg); 385 return ConvReg; 386 } 387 388 bool SPIRVLegalizerInfo::legalizeCustom( 389 LegalizerHelper &Helper, MachineInstr &MI, 390 LostDebugLocObserver &LocObserver) const { 391 auto Opc = MI.getOpcode(); 392 MachineRegisterInfo &MRI = MI.getMF()->getRegInfo(); 393 if (!isTypeFoldingSupported(Opc)) { 394 assert(Opc == TargetOpcode::G_ICMP); 395 assert(GR->getSPIRVTypeForVReg(MI.getOperand(0).getReg())); 396 auto &Op0 = MI.getOperand(2); 397 auto &Op1 = MI.getOperand(3); 398 Register Reg0 = Op0.getReg(); 399 Register Reg1 = Op1.getReg(); 400 CmpInst::Predicate Cond = 401 static_cast<CmpInst::Predicate>(MI.getOperand(1).getPredicate()); 402 if ((!ST->canDirectlyComparePointers() || 403 (Cond != CmpInst::ICMP_EQ && Cond != CmpInst::ICMP_NE)) && 404 MRI.getType(Reg0).isPointer() && MRI.getType(Reg1).isPointer()) { 405 LLT ConvT = LLT::scalar(ST->getPointerSize()); 406 Type *LLVMTy = IntegerType::get(MI.getMF()->getFunction().getContext(), 407 ST->getPointerSize()); 408 SPIRVType *SpirvTy = GR->getOrCreateSPIRVType(LLVMTy, Helper.MIRBuilder); 409 Op0.setReg(convertPtrToInt(Reg0, ConvT, SpirvTy, Helper, MRI, GR)); 410 Op1.setReg(convertPtrToInt(Reg1, ConvT, SpirvTy, Helper, MRI, GR)); 411 } 412 return true; 413 } 414 // TODO: implement legalization for other opcodes. 415 return true; 416 } 417