//===- SPIRVISelLowering.cpp - SPIR-V DAG Lowering Impl ---------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements the SPIRVTargetLowering class.
//
//===----------------------------------------------------------------------===//

#include "SPIRVISelLowering.h"
#include "SPIRV.h"
#include "SPIRVInstrInfo.h"
#include "SPIRVRegisterBankInfo.h"
#include "SPIRVRegisterInfo.h"
#include "SPIRVSubtarget.h"
#include "SPIRVTargetMachine.h"
#include "llvm/CodeGen/MachineInstrBuilder.h"
#include "llvm/CodeGen/MachineRegisterInfo.h"
#include "llvm/IR/IntrinsicsSPIRV.h"

#define DEBUG_TYPE "spirv-lower"

using namespace llvm;

unsigned SPIRVTargetLowering::getNumRegistersForCallingConv(
    LLVMContext &Context, CallingConv::ID CC, EVT VT) const {
  // This code avoids a CallLowering failure inside getVectorTypeBreakdown
  // on v3i1 arguments. Maybe we need to return 1 for all types.
  // TODO: remove it once this case is supported by the default implementation.
  if (VT.isVector() && VT.getVectorNumElements() == 3 &&
      (VT.getVectorElementType() == MVT::i1 ||
       VT.getVectorElementType() == MVT::i8))
    return 1;
  if (!VT.isVector() && VT.isInteger() && VT.getSizeInBits() <= 64)
    return 1;
  return getNumRegisters(Context, VT);
}

MVT SPIRVTargetLowering::getRegisterTypeForCallingConv(LLVMContext &Context,
                                                       CallingConv::ID CC,
                                                       EVT VT) const {
  // This code avoids a CallLowering failure inside getVectorTypeBreakdown
  // on v3i1 arguments. Maybe we need to return i32 for all types.
  // TODO: remove it once this case is supported by the default implementation.
  if (VT.isVector() && VT.getVectorNumElements() == 3) {
    if (VT.getVectorElementType() == MVT::i1)
      return MVT::v4i1;
    else if (VT.getVectorElementType() == MVT::i8)
      return MVT::v4i8;
  }
  return getRegisterType(Context, VT);
}

bool SPIRVTargetLowering::getTgtMemIntrinsic(IntrinsicInfo &Info,
                                             const CallInst &I,
                                             MachineFunction &MF,
                                             unsigned Intrinsic) const {
  unsigned AlignIdx = 3;
  switch (Intrinsic) {
  case Intrinsic::spv_load:
    AlignIdx = 2;
    [[fallthrough]];
  case Intrinsic::spv_store: {
    if (I.getNumOperands() >= AlignIdx + 1) {
      auto *AlignOp = cast<ConstantInt>(I.getOperand(AlignIdx));
      Info.align = Align(AlignOp->getZExtValue());
    }
    Info.flags = static_cast<MachineMemOperand::Flags>(
        cast<ConstantInt>(I.getOperand(AlignIdx - 1))->getZExtValue());
    Info.memVT = MVT::i64;
    // TODO: take into account opaque pointers (don't use getElementType).
    // MVT::getVT(PtrTy->getElementType());
    return true;
  }
  default:
    break;
  }
  return false;
}
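
// Argument layout assumed by getTgtMemIntrinsic above, deduced from the
// operand indices it reads (see IntrinsicsSPIRV.td for the authoritative
// definitions):
//   llvm.spv.load(ptr, flags, align)         -> flags at arg 1, align at arg 2
//   llvm.spv.store(value, ptr, flags, align) -> flags at arg 2, align at arg 3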

std::pair<unsigned, const TargetRegisterClass *>
SPIRVTargetLowering::getRegForInlineAsmConstraint(const TargetRegisterInfo *TRI,
                                                  StringRef Constraint,
                                                  MVT VT) const {
  const TargetRegisterClass *RC = nullptr;
  if (Constraint.starts_with("{"))
    return std::make_pair(0u, RC);

  if (VT.isFloatingPoint())
    RC = VT.isVector() ? &SPIRV::vfIDRegClass : &SPIRV::fIDRegClass;
  else if (VT.isInteger())
    RC = VT.isVector() ? &SPIRV::vIDRegClass : &SPIRV::iIDRegClass;
  else
    RC = &SPIRV::iIDRegClass;

  return std::make_pair(0u, RC);
}

inline Register getTypeReg(MachineRegisterInfo *MRI, Register OpReg) {
  SPIRVType *TypeInst = MRI->getVRegDef(OpReg);
  return TypeInst && TypeInst->getOpcode() == SPIRV::OpFunctionParameter
             ? TypeInst->getOperand(1).getReg()
             : OpReg;
}

static void doInsertBitcast(const SPIRVSubtarget &STI, MachineRegisterInfo *MRI,
                            SPIRVGlobalRegistry &GR, MachineInstr &I,
                            Register OpReg, unsigned OpIdx,
                            SPIRVType *NewPtrType) {
  MachineIRBuilder MIB(I);
  Register NewReg = createVirtualRegister(NewPtrType, &GR, MRI, MIB.getMF());
  bool Res = MIB.buildInstr(SPIRV::OpBitcast)
                 .addDef(NewReg)
                 .addUse(GR.getSPIRVTypeID(NewPtrType))
                 .addUse(OpReg)
                 .constrainAllUses(*STI.getInstrInfo(), *STI.getRegisterInfo(),
                                   *STI.getRegBankInfo());
  if (!Res)
    report_fatal_error("insert validation bitcast: cannot constrain all uses");
  I.getOperand(OpIdx).setReg(NewReg);
}

static SPIRVType *createNewPtrType(SPIRVGlobalRegistry &GR, MachineInstr &I,
                                   SPIRVType *OpType, bool ReuseType,
                                   bool EmitIR, SPIRVType *ResType,
                                   const Type *ResTy) {
  SPIRV::StorageClass::StorageClass SC =
      static_cast<SPIRV::StorageClass::StorageClass>(
          OpType->getOperand(1).getImm());
  MachineIRBuilder MIB(I);
  SPIRVType *NewBaseType =
      ReuseType ? ResType
                : GR.getOrCreateSPIRVType(
                      ResTy, MIB, SPIRV::AccessQualifier::ReadWrite, EmitIR);
  return GR.getOrCreateSPIRVPointerType(NewBaseType, MIB, SC);
}

// Insert a bitcast before the instruction to keep SPIR-V code valid
// when there is a type mismatch between result and operand types.
static void validatePtrTypes(const SPIRVSubtarget &STI,
                             MachineRegisterInfo *MRI, SPIRVGlobalRegistry &GR,
                             MachineInstr &I, unsigned OpIdx,
                             SPIRVType *ResType, const Type *ResTy = nullptr) {
  // Get the operand's type.
  MachineFunction *MF = I.getParent()->getParent();
  Register OpReg = I.getOperand(OpIdx).getReg();
  Register OpTypeReg = getTypeReg(MRI, OpReg);
  SPIRVType *OpType = GR.getSPIRVTypeForVReg(OpTypeReg, MF);
  if (!ResType || !OpType || OpType->getOpcode() != SPIRV::OpTypePointer)
    return;
  // Get the operand's pointee type.
  Register ElemTypeReg = OpType->getOperand(2).getReg();
  SPIRVType *ElemType = GR.getSPIRVTypeForVReg(ElemTypeReg, MF);
  if (!ElemType)
    return;
  // Check whether we need a bitcast to make the instruction valid.
  bool IsSameMF = MF == ResType->getParent()->getParent();
  bool IsEqualTypes = IsSameMF ? ElemType == ResType
                               : GR.getTypeForSPIRVType(ElemType) == ResTy;
  if (IsEqualTypes)
    return;
  // There is a type mismatch between the result and operand types, so insert
  // a bitcast before the instruction to keep the SPIR-V code valid.
  SPIRVType *NewPtrType =
      createNewPtrType(GR, I, OpType, IsSameMF, false, ResType, ResTy);
  if (!GR.isBitcastCompatible(NewPtrType, OpType))
    report_fatal_error(
        "insert validation bitcast: incompatible result and operand types");
  doInsertBitcast(STI, MRI, GR, I, OpReg, OpIdx, NewPtrType);
}
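
// Illustrative effect of validatePtrTypes on an OpLoad whose pointer operand
// has a mismatched pointee type (all names below are made up, not taken from
// a real module):
//   %op  = OpVariable %_ptr_Function_uchar Function
//   %res = OpLoad %uint %op           ; pointee (uchar) != result type (uint)
// becomes
//   %new = OpBitcast %_ptr_Function_uint %op
//   %res = OpLoad %uint %new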

// Insert a bitcast before OpGroupWaitEvents if the last argument is a pointer
// that doesn't point to OpTypeEvent.
static void validateGroupWaitEventsPtr(const SPIRVSubtarget &STI,
                                       MachineRegisterInfo *MRI,
                                       SPIRVGlobalRegistry &GR,
                                       MachineInstr &I) {
  constexpr unsigned OpIdx = 2;
  MachineFunction *MF = I.getParent()->getParent();
  Register OpReg = I.getOperand(OpIdx).getReg();
  Register OpTypeReg = getTypeReg(MRI, OpReg);
  SPIRVType *OpType = GR.getSPIRVTypeForVReg(OpTypeReg, MF);
  if (!OpType || OpType->getOpcode() != SPIRV::OpTypePointer)
    return;
  SPIRVType *ElemType = GR.getSPIRVTypeForVReg(OpType->getOperand(2).getReg());
  if (!ElemType || ElemType->getOpcode() == SPIRV::OpTypeEvent)
    return;
  // Insert a bitcast before the instruction to keep SPIR-V code valid.
  LLVMContext &Context = MF->getFunction().getContext();
  SPIRVType *NewPtrType =
      createNewPtrType(GR, I, OpType, false, true, nullptr,
                       TargetExtType::get(Context, "spirv.Event"));
  doInsertBitcast(STI, MRI, GR, I, OpReg, OpIdx, NewPtrType);
}

static void validateLifetimeStart(const SPIRVSubtarget &STI,
                                  MachineRegisterInfo *MRI,
                                  SPIRVGlobalRegistry &GR, MachineInstr &I) {
  Register PtrReg = I.getOperand(0).getReg();
  MachineFunction *MF = I.getParent()->getParent();
  Register PtrTypeReg = getTypeReg(MRI, PtrReg);
  SPIRVType *PtrType = GR.getSPIRVTypeForVReg(PtrTypeReg, MF);
  SPIRVType *PointeeElemType = PtrType ? GR.getPointeeType(PtrType) : nullptr;
  if (!PointeeElemType || PointeeElemType->getOpcode() == SPIRV::OpTypeVoid ||
      (PointeeElemType->getOpcode() == SPIRV::OpTypeInt &&
       PointeeElemType->getOperand(1).getImm() == 8))
    return;
  // To keep the code valid, a bitcast to an i8 pointer must be inserted.
  SPIRV::StorageClass::StorageClass SC =
      static_cast<SPIRV::StorageClass::StorageClass>(
          PtrType->getOperand(1).getImm());
  MachineIRBuilder MIB(I);
  LLVMContext &Context = MF->getFunction().getContext();
  SPIRVType *ElemType =
      GR.getOrCreateSPIRVType(IntegerType::getInt8Ty(Context), MIB);
  SPIRVType *NewPtrType = GR.getOrCreateSPIRVPointerType(ElemType, MIB, SC);
  doInsertBitcast(STI, MRI, GR, I, PtrReg, 0, NewPtrType);
}
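
// Illustrative effect of validateLifetimeStart (register and type names are
// made up): a lifetime marker whose pointer does not already point to i8 or
// void has its operand rewritten through a bitcast to an i8 pointer in the
// same storage class.
//   %p = OpVariable %_ptr_Function_uint Function
//   OpLifetimeStart %p 4
// becomes
//   %c = OpBitcast %_ptr_Function_uchar %p
//   OpLifetimeStart %c 4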

static void validatePtrUnwrapStructField(const SPIRVSubtarget &STI,
                                         MachineRegisterInfo *MRI,
                                         SPIRVGlobalRegistry &GR,
                                         MachineInstr &I, unsigned OpIdx) {
  MachineFunction *MF = I.getParent()->getParent();
  Register OpReg = I.getOperand(OpIdx).getReg();
  Register OpTypeReg = getTypeReg(MRI, OpReg);
  SPIRVType *OpType = GR.getSPIRVTypeForVReg(OpTypeReg, MF);
  if (!OpType || OpType->getOpcode() != SPIRV::OpTypePointer)
    return;
  SPIRVType *ElemType = GR.getSPIRVTypeForVReg(OpType->getOperand(2).getReg());
  if (!ElemType || ElemType->getOpcode() != SPIRV::OpTypeStruct ||
      ElemType->getNumOperands() != 2)
    return;
  // It's a structure-wrapper around another type with a single member field.
  SPIRVType *MemberType =
      GR.getSPIRVTypeForVReg(ElemType->getOperand(1).getReg());
  if (!MemberType)
    return;
  unsigned MemberTypeOp = MemberType->getOpcode();
  if (MemberTypeOp != SPIRV::OpTypeVector && MemberTypeOp != SPIRV::OpTypeInt &&
      MemberTypeOp != SPIRV::OpTypeFloat && MemberTypeOp != SPIRV::OpTypeBool)
    return;
  // It's a structure-wrapper around a valid type. Insert a bitcast before the
  // instruction to keep SPIR-V code valid.
  SPIRV::StorageClass::StorageClass SC =
      static_cast<SPIRV::StorageClass::StorageClass>(
          OpType->getOperand(1).getImm());
  MachineIRBuilder MIB(I);
  SPIRVType *NewPtrType = GR.getOrCreateSPIRVPointerType(MemberType, MIB, SC);
  doInsertBitcast(STI, MRI, GR, I, OpReg, OpIdx, NewPtrType);
}

// Insert a bitcast before the function call instruction to keep SPIR-V code
// valid when there is a type mismatch between actual and expected types of an
// argument:
// %formal = OpFunctionParameter %formal_type
// ...
// %res = OpFunctionCall %ty %fun %actual ...
// implies that %actual is of %formal_type; with opaque pointers we may need to
// insert a bitcast to ensure this.
void validateFunCallMachineDef(const SPIRVSubtarget &STI,
                               MachineRegisterInfo *DefMRI,
                               MachineRegisterInfo *CallMRI,
                               SPIRVGlobalRegistry &GR, MachineInstr &FunCall,
                               MachineInstr *FunDef) {
  if (FunDef->getOpcode() != SPIRV::OpFunction)
    return;
  unsigned OpIdx = 3;
  for (FunDef = FunDef->getNextNode();
       FunDef && FunDef->getOpcode() == SPIRV::OpFunctionParameter &&
       OpIdx < FunCall.getNumOperands();
       FunDef = FunDef->getNextNode(), OpIdx++) {
    SPIRVType *DefPtrType = DefMRI->getVRegDef(FunDef->getOperand(1).getReg());
    SPIRVType *DefElemType =
        DefPtrType && DefPtrType->getOpcode() == SPIRV::OpTypePointer
            ? GR.getSPIRVTypeForVReg(DefPtrType->getOperand(2).getReg(),
                                     DefPtrType->getParent()->getParent())
            : nullptr;
    if (DefElemType) {
      const Type *DefElemTy = GR.getTypeForSPIRVType(DefElemType);
      // validatePtrTypes() works in the context of the call site. When we
      // process recorded forward calls we need to switch the context to the
      // (forward) call site and then restore it back to the current machine
      // function.
      MachineFunction *CurMF =
          GR.setCurrentFunc(*FunCall.getParent()->getParent());
      validatePtrTypes(STI, CallMRI, GR, FunCall, OpIdx, DefElemType,
                       DefElemTy);
      GR.setCurrentFunc(*CurMF);
    }
  }
}

// Ensure there is no mismatch between actual and expected arg types: calls
// with a processed definition. Return a Function pointer if it's a forward
// call (ahead of the definition), and nullptr otherwise.
const Function *validateFunCall(const SPIRVSubtarget &STI,
                                MachineRegisterInfo *CallMRI,
                                SPIRVGlobalRegistry &GR,
                                MachineInstr &FunCall) {
  const GlobalValue *GV = FunCall.getOperand(2).getGlobal();
  const Function *F = dyn_cast<Function>(GV);
  MachineInstr *FunDef =
      const_cast<MachineInstr *>(GR.getFunctionDefinition(F));
  if (!FunDef)
    return F;
  MachineRegisterInfo *DefMRI = &FunDef->getParent()->getParent()->getRegInfo();
  validateFunCallMachineDef(STI, DefMRI, CallMRI, GR, FunCall, FunDef);
  return nullptr;
}

// Ensure there is no mismatch between actual and expected arg types: calls
// ahead of a processed definition.
void validateForwardCalls(const SPIRVSubtarget &STI,
                          MachineRegisterInfo *DefMRI, SPIRVGlobalRegistry &GR,
                          MachineInstr &FunDef) {
  const Function *F = GR.getFunctionByDefinition(&FunDef);
  if (SmallPtrSet<MachineInstr *, 8> *FwdCalls = GR.getForwardCalls(F))
    for (MachineInstr *FunCall : *FwdCalls) {
      MachineRegisterInfo *CallMRI =
          &FunCall->getParent()->getParent()->getRegInfo();
      validateFunCallMachineDef(STI, DefMRI, CallMRI, GR, *FunCall, &FunDef);
    }
}
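
// Illustrative module order handled by validateFunCall/validateForwardCalls
// together (names are made up): a call lowered before its callee's body is
// recorded as a forward call and re-validated once the definition, with its
// OpFunctionParameter types, has been processed.
//   %r      = OpFunctionCall %uint %callee %actual    ; seen first, recorded
//   ...
//   %callee = OpFunction %uint None %fnTy             ; processed later,
//   %formal = OpFunctionParameter %_ptr_Function_uint ; forward calls revisited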

// Validation of an access chain.
void validateAccessChain(const SPIRVSubtarget &STI, MachineRegisterInfo *MRI,
                         SPIRVGlobalRegistry &GR, MachineInstr &I) {
  SPIRVType *BaseTypeInst = GR.getSPIRVTypeForVReg(I.getOperand(0).getReg());
  if (BaseTypeInst && BaseTypeInst->getOpcode() == SPIRV::OpTypePointer) {
    SPIRVType *BaseElemType =
        GR.getSPIRVTypeForVReg(BaseTypeInst->getOperand(2).getReg());
    validatePtrTypes(STI, MRI, GR, I, 2, BaseElemType);
  }
}

// TODO: the logic of inserting additional bitcasts is to be moved
// to pre-IRTranslation passes eventually.
void SPIRVTargetLowering::finalizeLowering(MachineFunction &MF) const {
  // finalizeLowering() is called twice (see GlobalISel/InstructionSelect.cpp),
  // so we'd like to avoid a needless second processing pass.
  if (ProcessedMF.find(&MF) != ProcessedMF.end())
    return;

  MachineRegisterInfo *MRI = &MF.getRegInfo();
  SPIRVGlobalRegistry &GR = *STI.getSPIRVGlobalRegistry();
  GR.setCurrentFunc(MF);
  for (MachineFunction::iterator I = MF.begin(), E = MF.end(); I != E; ++I) {
    MachineBasicBlock *MBB = &*I;
    SmallPtrSet<MachineInstr *, 8> ToMove;
    for (MachineBasicBlock::iterator MBBI = MBB->begin(), MBBE = MBB->end();
         MBBI != MBBE;) {
      MachineInstr &MI = *MBBI++;
      switch (MI.getOpcode()) {
      case SPIRV::OpAtomicLoad:
      case SPIRV::OpAtomicExchange:
      case SPIRV::OpAtomicCompareExchange:
      case SPIRV::OpAtomicCompareExchangeWeak:
      case SPIRV::OpAtomicIIncrement:
      case SPIRV::OpAtomicIDecrement:
      case SPIRV::OpAtomicIAdd:
      case SPIRV::OpAtomicISub:
      case SPIRV::OpAtomicSMin:
      case SPIRV::OpAtomicUMin:
      case SPIRV::OpAtomicSMax:
      case SPIRV::OpAtomicUMax:
      case SPIRV::OpAtomicAnd:
      case SPIRV::OpAtomicOr:
      case SPIRV::OpAtomicXor:
        // For the above listed instructions
        //   OpAtomicXXX <ResType>, ptr %Op, ...
        // implies that %Op is a pointer to <ResType>.
      case SPIRV::OpLoad:
        // OpLoad <ResType>, ptr %Op implies that %Op is a pointer to <ResType>
        validatePtrTypes(STI, MRI, GR, MI, 2,
                         GR.getSPIRVTypeForVReg(MI.getOperand(0).getReg()));
        break;
      case SPIRV::OpAtomicStore:
        // OpAtomicStore ptr %Op, <Scope>, <Mem>, <Obj>
        // implies that %Op points to the <Obj>'s type.
        validatePtrTypes(STI, MRI, GR, MI, 0,
                         GR.getSPIRVTypeForVReg(MI.getOperand(3).getReg()));
        break;
      case SPIRV::OpStore:
        // OpStore ptr %Op, <Obj> implies that %Op points to the <Obj>'s type.
        validatePtrTypes(STI, MRI, GR, MI, 0,
                         GR.getSPIRVTypeForVReg(MI.getOperand(1).getReg()));
        break;
      case SPIRV::OpPtrCastToGeneric:
      case SPIRV::OpGenericCastToPtr:
        validateAccessChain(STI, MRI, GR, MI);
        break;
      case SPIRV::OpPtrAccessChain:
      case SPIRV::OpInBoundsPtrAccessChain:
        if (MI.getNumOperands() == 4)
          validateAccessChain(STI, MRI, GR, MI);
        break;

      case SPIRV::OpFunctionCall:
        // Ensure there is no mismatch between actual and expected arg types:
        // calls with a processed definition.
        if (MI.getNumOperands() > 3)
          if (const Function *F = validateFunCall(STI, MRI, GR, MI))
            GR.addForwardCall(F, &MI);
        break;
      case SPIRV::OpFunction:
        // Ensure there is no mismatch between actual and expected arg types:
        // calls ahead of a processed definition.
        validateForwardCalls(STI, MRI, GR, MI);
        break;

      // Ensure that LLVM IR add/sub instructions result in logical SPIR-V
      // instructions when applied to bool type.
      case SPIRV::OpIAddS:
      case SPIRV::OpIAddV:
      case SPIRV::OpISubS:
      case SPIRV::OpISubV:
        if (GR.isScalarOrVectorOfType(MI.getOperand(1).getReg(),
                                      SPIRV::OpTypeBool))
          MI.setDesc(STI.getInstrInfo()->get(SPIRV::OpLogicalNotEqual));
        break;

      // Ensure that LLVM IR bitwise instructions result in logical SPIR-V
      // instructions when applied to bool type.
      case SPIRV::OpBitwiseOrS:
      case SPIRV::OpBitwiseOrV:
        if (GR.isScalarOrVectorOfType(MI.getOperand(1).getReg(),
                                      SPIRV::OpTypeBool))
          MI.setDesc(STI.getInstrInfo()->get(SPIRV::OpLogicalOr));
        break;
      case SPIRV::OpBitwiseAndS:
      case SPIRV::OpBitwiseAndV:
        if (GR.isScalarOrVectorOfType(MI.getOperand(1).getReg(),
                                      SPIRV::OpTypeBool))
          MI.setDesc(STI.getInstrInfo()->get(SPIRV::OpLogicalAnd));
        break;
      case SPIRV::OpBitwiseXorS:
      case SPIRV::OpBitwiseXorV:
        if (GR.isScalarOrVectorOfType(MI.getOperand(1).getReg(),
                                      SPIRV::OpTypeBool))
          MI.setDesc(STI.getInstrInfo()->get(SPIRV::OpLogicalNotEqual));
        break;
      case SPIRV::OpLifetimeStart:
      case SPIRV::OpLifetimeStop:
        if (MI.getOperand(1).getImm() > 0)
          validateLifetimeStart(STI, MRI, GR, MI);
        break;
      case SPIRV::OpGroupAsyncCopy:
        validatePtrUnwrapStructField(STI, MRI, GR, MI, 3);
        validatePtrUnwrapStructField(STI, MRI, GR, MI, 4);
        break;
      case SPIRV::OpGroupWaitEvents:
        // OpGroupWaitEvents ..., ..., <pointer to OpTypeEvent>
        validateGroupWaitEventsPtr(STI, MRI, GR, MI);
        break;
      case SPIRV::OpConstantI: {
        SPIRVType *Type = GR.getSPIRVTypeForVReg(MI.getOperand(1).getReg());
        if (Type->getOpcode() != SPIRV::OpTypeInt && MI.getOperand(2).isImm() &&
            MI.getOperand(2).getImm() == 0) {
          // Validate the null constant of a target extension type.
          MI.setDesc(STI.getInstrInfo()->get(SPIRV::OpConstantNull));
          for (unsigned i = MI.getNumOperands() - 1; i > 1; --i)
            MI.removeOperand(i);
        }
      } break;
      case SPIRV::OpPhi: {
        // Phi refers to a type definition that goes after the Phi
        // instruction, so that the virtual register definition of the type
        // doesn't dominate all uses. Let's place the type definition
        // instruction at the end of the predecessor.
        MachineBasicBlock *Curr = MI.getParent();
        SPIRVType *Type = GR.getSPIRVTypeForVReg(MI.getOperand(1).getReg());
        if (Type->getParent() == Curr && !Curr->pred_empty())
          ToMove.insert(const_cast<MachineInstr *>(Type));
      } break;
      case SPIRV::OpExtInst: {
        // Validate OpenCL extended instructions that take a pointer argument.
        if (!MI.getOperand(2).isImm() || !MI.getOperand(3).isImm() ||
            MI.getOperand(2).getImm() != SPIRV::InstructionSet::OpenCL_std)
          continue;
        switch (MI.getOperand(3).getImm()) {
        case SPIRV::OpenCLExtInst::frexp:
        case SPIRV::OpenCLExtInst::lgamma_r:
        case SPIRV::OpenCLExtInst::remquo: {
          // The last operand must be a pointer to i32 or to a vector of i32
          // values.
          MachineIRBuilder MIB(MI);
          SPIRVType *Int32Type = GR.getOrCreateSPIRVIntegerType(32, MIB);
          SPIRVType *RetType = MRI->getVRegDef(MI.getOperand(1).getReg());
          assert(RetType && "Expected return type");
          validatePtrTypes(
              STI, MRI, GR, MI, MI.getNumOperands() - 1,
              RetType->getOpcode() != SPIRV::OpTypeVector
                  ? Int32Type
                  : GR.getOrCreateSPIRVVectorType(
                        Int32Type, RetType->getOperand(2).getImm(), MIB));
        } break;
        case SPIRV::OpenCLExtInst::fract:
        case SPIRV::OpenCLExtInst::modf:
        case SPIRV::OpenCLExtInst::sincos:
          // The last operand must be a pointer to the base type represented
          // by the previous operand.
          assert(MI.getOperand(MI.getNumOperands() - 2).isReg() &&
                 "Expected v-reg");
          validatePtrTypes(
              STI, MRI, GR, MI, MI.getNumOperands() - 1,
              GR.getSPIRVTypeForVReg(
                  MI.getOperand(MI.getNumOperands() - 2).getReg()));
          break;
        case SPIRV::OpenCLExtInst::prefetch:
          // The expected `ptr` type is a pointer to float, integer or vector,
          // but the pointee value can be wrapped into a struct.
          assert(MI.getOperand(MI.getNumOperands() - 2).isReg() &&
                 "Expected v-reg");
          validatePtrUnwrapStructField(STI, MRI, GR, MI,
                                       MI.getNumOperands() - 2);
          break;
        }
      } break;
      }
    }
    for (MachineInstr *MI : ToMove) {
      MachineBasicBlock *Curr = MI->getParent();
      MachineBasicBlock *Pred = *Curr->pred_begin();
      Pred->insert(Pred->getFirstTerminator(), Curr->remove_instr(MI));
    }
  }
  ProcessedMF.insert(&MF);
  TargetLowering::finalizeLowering(MF);
}