xref: /llvm-project/llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp (revision b5132b7d044a5bc83eba9b09bd158cd77a511403)
1 //===- SPIRVISelLowering.cpp - SPIR-V DAG Lowering Impl ---------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the SPIRVTargetLowering class.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "SPIRVISelLowering.h"
14 #include "SPIRV.h"
15 #include "SPIRVInstrInfo.h"
16 #include "SPIRVRegisterBankInfo.h"
17 #include "SPIRVRegisterInfo.h"
18 #include "SPIRVSubtarget.h"
19 #include "SPIRVTargetMachine.h"
20 #include "llvm/CodeGen/MachineInstrBuilder.h"
21 #include "llvm/CodeGen/MachineRegisterInfo.h"
22 #include "llvm/IR/IntrinsicsSPIRV.h"
23 
24 #define DEBUG_TYPE "spirv-lower"
25 
26 using namespace llvm;
27 
28 unsigned SPIRVTargetLowering::getNumRegistersForCallingConv(
29     LLVMContext &Context, CallingConv::ID CC, EVT VT) const {
30   // This code avoids CallLowering fail inside getVectorTypeBreakdown
31   // on v3i1 arguments. Maybe we need to return 1 for all types.
32   // TODO: remove it once this case is supported by the default implementation.
33   if (VT.isVector() && VT.getVectorNumElements() == 3 &&
34       (VT.getVectorElementType() == MVT::i1 ||
35        VT.getVectorElementType() == MVT::i8))
36     return 1;
37   if (!VT.isVector() && VT.isInteger() && VT.getSizeInBits() <= 64)
38     return 1;
39   return getNumRegisters(Context, VT);
40 }
41 
42 MVT SPIRVTargetLowering::getRegisterTypeForCallingConv(LLVMContext &Context,
43                                                        CallingConv::ID CC,
44                                                        EVT VT) const {
45   // This code avoids CallLowering fail inside getVectorTypeBreakdown
46   // on v3i1 arguments. Maybe we need to return i32 for all types.
47   // TODO: remove it once this case is supported by the default implementation.
48   if (VT.isVector() && VT.getVectorNumElements() == 3) {
49     if (VT.getVectorElementType() == MVT::i1)
50       return MVT::v4i1;
51     else if (VT.getVectorElementType() == MVT::i8)
52       return MVT::v4i8;
53   }
54   return getRegisterType(Context, VT);
55 }
56 
57 bool SPIRVTargetLowering::getTgtMemIntrinsic(IntrinsicInfo &Info,
58                                              const CallInst &I,
59                                              MachineFunction &MF,
60                                              unsigned Intrinsic) const {
61   unsigned AlignIdx = 3;
62   switch (Intrinsic) {
63   case Intrinsic::spv_load:
64     AlignIdx = 2;
65     [[fallthrough]];
66   case Intrinsic::spv_store: {
67     if (I.getNumOperands() >= AlignIdx + 1) {
68       auto *AlignOp = cast<ConstantInt>(I.getOperand(AlignIdx));
69       Info.align = Align(AlignOp->getZExtValue());
70     }
71     Info.flags = static_cast<MachineMemOperand::Flags>(
72         cast<ConstantInt>(I.getOperand(AlignIdx - 1))->getZExtValue());
73     Info.memVT = MVT::i64;
74     // TODO: take into account opaque pointers (don't use getElementType).
75     // MVT::getVT(PtrTy->getElementType());
76     return true;
77     break;
78   }
79   default:
80     break;
81   }
82   return false;
83 }
84 
85 std::pair<unsigned, const TargetRegisterClass *>
86 SPIRVTargetLowering::getRegForInlineAsmConstraint(const TargetRegisterInfo *TRI,
87                                                   StringRef Constraint,
88                                                   MVT VT) const {
89   const TargetRegisterClass *RC = nullptr;
90   if (Constraint.starts_with("{"))
91     return std::make_pair(0u, RC);
92 
93   if (VT.isFloatingPoint())
94     RC = VT.isVector() ? &SPIRV::vfIDRegClass : &SPIRV::fIDRegClass;
95   else if (VT.isInteger())
96     RC = VT.isVector() ? &SPIRV::vIDRegClass : &SPIRV::iIDRegClass;
97   else
98     RC = &SPIRV::iIDRegClass;
99 
100   return std::make_pair(0u, RC);
101 }
102 
103 inline Register getTypeReg(MachineRegisterInfo *MRI, Register OpReg) {
104   SPIRVType *TypeInst = MRI->getVRegDef(OpReg);
105   return TypeInst && TypeInst->getOpcode() == SPIRV::OpFunctionParameter
106              ? TypeInst->getOperand(1).getReg()
107              : OpReg;
108 }
109 
110 static void doInsertBitcast(const SPIRVSubtarget &STI, MachineRegisterInfo *MRI,
111                             SPIRVGlobalRegistry &GR, MachineInstr &I,
112                             Register OpReg, unsigned OpIdx,
113                             SPIRVType *NewPtrType) {
114   MachineIRBuilder MIB(I);
115   Register NewReg = createVirtualRegister(NewPtrType, &GR, MRI, MIB.getMF());
116   bool Res = MIB.buildInstr(SPIRV::OpBitcast)
117                  .addDef(NewReg)
118                  .addUse(GR.getSPIRVTypeID(NewPtrType))
119                  .addUse(OpReg)
120                  .constrainAllUses(*STI.getInstrInfo(), *STI.getRegisterInfo(),
121                                    *STI.getRegBankInfo());
122   if (!Res)
123     report_fatal_error("insert validation bitcast: cannot constrain all uses");
124   I.getOperand(OpIdx).setReg(NewReg);
125 }
126 
127 static SPIRVType *createNewPtrType(SPIRVGlobalRegistry &GR, MachineInstr &I,
128                                    SPIRVType *OpType, bool ReuseType,
129                                    bool EmitIR, SPIRVType *ResType,
130                                    const Type *ResTy) {
131   SPIRV::StorageClass::StorageClass SC =
132       static_cast<SPIRV::StorageClass::StorageClass>(
133           OpType->getOperand(1).getImm());
134   MachineIRBuilder MIB(I);
135   SPIRVType *NewBaseType =
136       ReuseType ? ResType
137                 : GR.getOrCreateSPIRVType(
138                       ResTy, MIB, SPIRV::AccessQualifier::ReadWrite, EmitIR);
139   return GR.getOrCreateSPIRVPointerType(NewBaseType, MIB, SC);
140 }
141 
142 // Insert a bitcast before the instruction to keep SPIR-V code valid
143 // when there is a type mismatch between results and operand types.
144 static void validatePtrTypes(const SPIRVSubtarget &STI,
145                              MachineRegisterInfo *MRI, SPIRVGlobalRegistry &GR,
146                              MachineInstr &I, unsigned OpIdx,
147                              SPIRVType *ResType, const Type *ResTy = nullptr) {
148   // Get operand type
149   MachineFunction *MF = I.getParent()->getParent();
150   Register OpReg = I.getOperand(OpIdx).getReg();
151   Register OpTypeReg = getTypeReg(MRI, OpReg);
152   SPIRVType *OpType = GR.getSPIRVTypeForVReg(OpTypeReg, MF);
153   if (!ResType || !OpType || OpType->getOpcode() != SPIRV::OpTypePointer)
154     return;
155   // Get operand's pointee type
156   Register ElemTypeReg = OpType->getOperand(2).getReg();
157   SPIRVType *ElemType = GR.getSPIRVTypeForVReg(ElemTypeReg, MF);
158   if (!ElemType)
159     return;
160   // Check if we need a bitcast to make a statement valid
161   bool IsSameMF = MF == ResType->getParent()->getParent();
162   bool IsEqualTypes = IsSameMF ? ElemType == ResType
163                                : GR.getTypeForSPIRVType(ElemType) == ResTy;
164   if (IsEqualTypes)
165     return;
166   // There is a type mismatch between results and operand types
167   // and we insert a bitcast before the instruction to keep SPIR-V code valid
168   SPIRVType *NewPtrType =
169       createNewPtrType(GR, I, OpType, IsSameMF, false, ResType, ResTy);
170   if (!GR.isBitcastCompatible(NewPtrType, OpType))
171     report_fatal_error(
172         "insert validation bitcast: incompatible result and operand types");
173   doInsertBitcast(STI, MRI, GR, I, OpReg, OpIdx, NewPtrType);
174 }
175 
176 // Insert a bitcast before OpGroupWaitEvents if the last argument is a pointer
177 // that doesn't point to OpTypeEvent.
178 static void validateGroupWaitEventsPtr(const SPIRVSubtarget &STI,
179                                        MachineRegisterInfo *MRI,
180                                        SPIRVGlobalRegistry &GR,
181                                        MachineInstr &I) {
182   constexpr unsigned OpIdx = 2;
183   MachineFunction *MF = I.getParent()->getParent();
184   Register OpReg = I.getOperand(OpIdx).getReg();
185   Register OpTypeReg = getTypeReg(MRI, OpReg);
186   SPIRVType *OpType = GR.getSPIRVTypeForVReg(OpTypeReg, MF);
187   if (!OpType || OpType->getOpcode() != SPIRV::OpTypePointer)
188     return;
189   SPIRVType *ElemType = GR.getSPIRVTypeForVReg(OpType->getOperand(2).getReg());
190   if (!ElemType || ElemType->getOpcode() == SPIRV::OpTypeEvent)
191     return;
192   // Insert a bitcast before the instruction to keep SPIR-V code valid.
193   LLVMContext &Context = MF->getFunction().getContext();
194   SPIRVType *NewPtrType =
195       createNewPtrType(GR, I, OpType, false, true, nullptr,
196                        TargetExtType::get(Context, "spirv.Event"));
197   doInsertBitcast(STI, MRI, GR, I, OpReg, OpIdx, NewPtrType);
198 }
199 
200 static void validateLifetimeStart(const SPIRVSubtarget &STI,
201                                   MachineRegisterInfo *MRI,
202                                   SPIRVGlobalRegistry &GR, MachineInstr &I) {
203   Register PtrReg = I.getOperand(0).getReg();
204   MachineFunction *MF = I.getParent()->getParent();
205   Register PtrTypeReg = getTypeReg(MRI, PtrReg);
206   SPIRVType *PtrType = GR.getSPIRVTypeForVReg(PtrTypeReg, MF);
207   SPIRVType *PonteeElemType = PtrType ? GR.getPointeeType(PtrType) : nullptr;
208   if (!PonteeElemType || PonteeElemType->getOpcode() == SPIRV::OpTypeVoid ||
209       (PonteeElemType->getOpcode() == SPIRV::OpTypeInt &&
210        PonteeElemType->getOperand(1).getImm() == 8))
211     return;
212   // To keep the code valid a bitcast must be inserted
213   SPIRV::StorageClass::StorageClass SC =
214       static_cast<SPIRV::StorageClass::StorageClass>(
215           PtrType->getOperand(1).getImm());
216   MachineIRBuilder MIB(I);
217   LLVMContext &Context = MF->getFunction().getContext();
218   SPIRVType *ElemType =
219       GR.getOrCreateSPIRVType(IntegerType::getInt8Ty(Context), MIB);
220   SPIRVType *NewPtrType = GR.getOrCreateSPIRVPointerType(ElemType, MIB, SC);
221   doInsertBitcast(STI, MRI, GR, I, PtrReg, 0, NewPtrType);
222 }
223 
224 static void validatePtrUnwrapStructField(const SPIRVSubtarget &STI,
225                                          MachineRegisterInfo *MRI,
226                                          SPIRVGlobalRegistry &GR,
227                                          MachineInstr &I, unsigned OpIdx) {
228   MachineFunction *MF = I.getParent()->getParent();
229   Register OpReg = I.getOperand(OpIdx).getReg();
230   Register OpTypeReg = getTypeReg(MRI, OpReg);
231   SPIRVType *OpType = GR.getSPIRVTypeForVReg(OpTypeReg, MF);
232   if (!OpType || OpType->getOpcode() != SPIRV::OpTypePointer)
233     return;
234   SPIRVType *ElemType = GR.getSPIRVTypeForVReg(OpType->getOperand(2).getReg());
235   if (!ElemType || ElemType->getOpcode() != SPIRV::OpTypeStruct ||
236       ElemType->getNumOperands() != 2)
237     return;
238   // It's a structure-wrapper around another type with a single member field.
239   SPIRVType *MemberType =
240       GR.getSPIRVTypeForVReg(ElemType->getOperand(1).getReg());
241   if (!MemberType)
242     return;
243   unsigned MemberTypeOp = MemberType->getOpcode();
244   if (MemberTypeOp != SPIRV::OpTypeVector && MemberTypeOp != SPIRV::OpTypeInt &&
245       MemberTypeOp != SPIRV::OpTypeFloat && MemberTypeOp != SPIRV::OpTypeBool)
246     return;
247   // It's a structure-wrapper around a valid type. Insert a bitcast before the
248   // instruction to keep SPIR-V code valid.
249   SPIRV::StorageClass::StorageClass SC =
250       static_cast<SPIRV::StorageClass::StorageClass>(
251           OpType->getOperand(1).getImm());
252   MachineIRBuilder MIB(I);
253   SPIRVType *NewPtrType = GR.getOrCreateSPIRVPointerType(MemberType, MIB, SC);
254   doInsertBitcast(STI, MRI, GR, I, OpReg, OpIdx, NewPtrType);
255 }
256 
257 // Insert a bitcast before the function call instruction to keep SPIR-V code
258 // valid when there is a type mismatch between actual and expected types of an
259 // argument:
260 // %formal = OpFunctionParameter %formal_type
261 // ...
262 // %res = OpFunctionCall %ty %fun %actual ...
263 // implies that %actual is of %formal_type, and in case of opaque pointers.
264 // We may need to insert a bitcast to ensure this.
265 void validateFunCallMachineDef(const SPIRVSubtarget &STI,
266                                MachineRegisterInfo *DefMRI,
267                                MachineRegisterInfo *CallMRI,
268                                SPIRVGlobalRegistry &GR, MachineInstr &FunCall,
269                                MachineInstr *FunDef) {
270   if (FunDef->getOpcode() != SPIRV::OpFunction)
271     return;
272   unsigned OpIdx = 3;
273   for (FunDef = FunDef->getNextNode();
274        FunDef && FunDef->getOpcode() == SPIRV::OpFunctionParameter &&
275        OpIdx < FunCall.getNumOperands();
276        FunDef = FunDef->getNextNode(), OpIdx++) {
277     SPIRVType *DefPtrType = DefMRI->getVRegDef(FunDef->getOperand(1).getReg());
278     SPIRVType *DefElemType =
279         DefPtrType && DefPtrType->getOpcode() == SPIRV::OpTypePointer
280             ? GR.getSPIRVTypeForVReg(DefPtrType->getOperand(2).getReg(),
281                                      DefPtrType->getParent()->getParent())
282             : nullptr;
283     if (DefElemType) {
284       const Type *DefElemTy = GR.getTypeForSPIRVType(DefElemType);
285       // validatePtrTypes() works in the context if the call site
286       // When we process historical records about forward calls
287       // we need to switch context to the (forward) call site and
288       // then restore it back to the current machine function.
289       MachineFunction *CurMF =
290           GR.setCurrentFunc(*FunCall.getParent()->getParent());
291       validatePtrTypes(STI, CallMRI, GR, FunCall, OpIdx, DefElemType,
292                        DefElemTy);
293       GR.setCurrentFunc(*CurMF);
294     }
295   }
296 }
297 
298 // Ensure there is no mismatch between actual and expected arg types: calls
299 // with a processed definition. Return Function pointer if it's a forward
300 // call (ahead of definition), and nullptr otherwise.
301 const Function *validateFunCall(const SPIRVSubtarget &STI,
302                                 MachineRegisterInfo *CallMRI,
303                                 SPIRVGlobalRegistry &GR,
304                                 MachineInstr &FunCall) {
305   const GlobalValue *GV = FunCall.getOperand(2).getGlobal();
306   const Function *F = dyn_cast<Function>(GV);
307   MachineInstr *FunDef =
308       const_cast<MachineInstr *>(GR.getFunctionDefinition(F));
309   if (!FunDef)
310     return F;
311   MachineRegisterInfo *DefMRI = &FunDef->getParent()->getParent()->getRegInfo();
312   validateFunCallMachineDef(STI, DefMRI, CallMRI, GR, FunCall, FunDef);
313   return nullptr;
314 }
315 
316 // Ensure there is no mismatch between actual and expected arg types: calls
317 // ahead of a processed definition.
318 void validateForwardCalls(const SPIRVSubtarget &STI,
319                           MachineRegisterInfo *DefMRI, SPIRVGlobalRegistry &GR,
320                           MachineInstr &FunDef) {
321   const Function *F = GR.getFunctionByDefinition(&FunDef);
322   if (SmallPtrSet<MachineInstr *, 8> *FwdCalls = GR.getForwardCalls(F))
323     for (MachineInstr *FunCall : *FwdCalls) {
324       MachineRegisterInfo *CallMRI =
325           &FunCall->getParent()->getParent()->getRegInfo();
326       validateFunCallMachineDef(STI, DefMRI, CallMRI, GR, *FunCall, &FunDef);
327     }
328 }
329 
330 // Validation of an access chain.
331 void validateAccessChain(const SPIRVSubtarget &STI, MachineRegisterInfo *MRI,
332                          SPIRVGlobalRegistry &GR, MachineInstr &I) {
333   SPIRVType *BaseTypeInst = GR.getSPIRVTypeForVReg(I.getOperand(0).getReg());
334   if (BaseTypeInst && BaseTypeInst->getOpcode() == SPIRV::OpTypePointer) {
335     SPIRVType *BaseElemType =
336         GR.getSPIRVTypeForVReg(BaseTypeInst->getOperand(2).getReg());
337     validatePtrTypes(STI, MRI, GR, I, 2, BaseElemType);
338   }
339 }
340 
341 // TODO: the logic of inserting additional bitcast's is to be moved
342 // to pre-IRTranslation passes eventually
343 void SPIRVTargetLowering::finalizeLowering(MachineFunction &MF) const {
344   // finalizeLowering() is called twice (see GlobalISel/InstructionSelect.cpp)
345   // We'd like to avoid the needless second processing pass.
346   if (ProcessedMF.find(&MF) != ProcessedMF.end())
347     return;
348 
349   MachineRegisterInfo *MRI = &MF.getRegInfo();
350   SPIRVGlobalRegistry &GR = *STI.getSPIRVGlobalRegistry();
351   GR.setCurrentFunc(MF);
352   for (MachineFunction::iterator I = MF.begin(), E = MF.end(); I != E; ++I) {
353     MachineBasicBlock *MBB = &*I;
354     SmallPtrSet<MachineInstr *, 8> ToMove;
355     for (MachineBasicBlock::iterator MBBI = MBB->begin(), MBBE = MBB->end();
356          MBBI != MBBE;) {
357       MachineInstr &MI = *MBBI++;
358       switch (MI.getOpcode()) {
359       case SPIRV::OpAtomicLoad:
360       case SPIRV::OpAtomicExchange:
361       case SPIRV::OpAtomicCompareExchange:
362       case SPIRV::OpAtomicCompareExchangeWeak:
363       case SPIRV::OpAtomicIIncrement:
364       case SPIRV::OpAtomicIDecrement:
365       case SPIRV::OpAtomicIAdd:
366       case SPIRV::OpAtomicISub:
367       case SPIRV::OpAtomicSMin:
368       case SPIRV::OpAtomicUMin:
369       case SPIRV::OpAtomicSMax:
370       case SPIRV::OpAtomicUMax:
371       case SPIRV::OpAtomicAnd:
372       case SPIRV::OpAtomicOr:
373       case SPIRV::OpAtomicXor:
374         // for the above listed instructions
375         // OpAtomicXXX <ResType>, ptr %Op, ...
376         // implies that %Op is a pointer to <ResType>
377       case SPIRV::OpLoad:
378         // OpLoad <ResType>, ptr %Op implies that %Op is a pointer to <ResType>
379         validatePtrTypes(STI, MRI, GR, MI, 2,
380                          GR.getSPIRVTypeForVReg(MI.getOperand(0).getReg()));
381         break;
382       case SPIRV::OpAtomicStore:
383         // OpAtomicStore ptr %Op, <Scope>, <Mem>, <Obj>
384         // implies that %Op points to the <Obj>'s type
385         validatePtrTypes(STI, MRI, GR, MI, 0,
386                          GR.getSPIRVTypeForVReg(MI.getOperand(3).getReg()));
387         break;
388       case SPIRV::OpStore:
389         // OpStore ptr %Op, <Obj> implies that %Op points to the <Obj>'s type
390         validatePtrTypes(STI, MRI, GR, MI, 0,
391                          GR.getSPIRVTypeForVReg(MI.getOperand(1).getReg()));
392         break;
393       case SPIRV::OpPtrCastToGeneric:
394       case SPIRV::OpGenericCastToPtr:
395         validateAccessChain(STI, MRI, GR, MI);
396         break;
397       case SPIRV::OpPtrAccessChain:
398       case SPIRV::OpInBoundsPtrAccessChain:
399         if (MI.getNumOperands() == 4)
400           validateAccessChain(STI, MRI, GR, MI);
401         break;
402 
403       case SPIRV::OpFunctionCall:
404         // ensure there is no mismatch between actual and expected arg types:
405         // calls with a processed definition
406         if (MI.getNumOperands() > 3)
407           if (const Function *F = validateFunCall(STI, MRI, GR, MI))
408             GR.addForwardCall(F, &MI);
409         break;
410       case SPIRV::OpFunction:
411         // ensure there is no mismatch between actual and expected arg types:
412         // calls ahead of a processed definition
413         validateForwardCalls(STI, MRI, GR, MI);
414         break;
415 
416       // ensure that LLVM IR add/sub instructions result in logical SPIR-V
417       // instructions when applied to bool type
418       case SPIRV::OpIAddS:
419       case SPIRV::OpIAddV:
420       case SPIRV::OpISubS:
421       case SPIRV::OpISubV:
422         if (GR.isScalarOrVectorOfType(MI.getOperand(1).getReg(),
423                                       SPIRV::OpTypeBool))
424           MI.setDesc(STI.getInstrInfo()->get(SPIRV::OpLogicalNotEqual));
425         break;
426 
427       // ensure that LLVM IR bitwise instructions result in logical SPIR-V
428       // instructions when applied to bool type
429       case SPIRV::OpBitwiseOrS:
430       case SPIRV::OpBitwiseOrV:
431         if (GR.isScalarOrVectorOfType(MI.getOperand(1).getReg(),
432                                       SPIRV::OpTypeBool))
433           MI.setDesc(STI.getInstrInfo()->get(SPIRV::OpLogicalOr));
434         break;
435       case SPIRV::OpBitwiseAndS:
436       case SPIRV::OpBitwiseAndV:
437         if (GR.isScalarOrVectorOfType(MI.getOperand(1).getReg(),
438                                       SPIRV::OpTypeBool))
439           MI.setDesc(STI.getInstrInfo()->get(SPIRV::OpLogicalAnd));
440         break;
441       case SPIRV::OpBitwiseXorS:
442       case SPIRV::OpBitwiseXorV:
443         if (GR.isScalarOrVectorOfType(MI.getOperand(1).getReg(),
444                                       SPIRV::OpTypeBool))
445           MI.setDesc(STI.getInstrInfo()->get(SPIRV::OpLogicalNotEqual));
446         break;
447       case SPIRV::OpLifetimeStart:
448       case SPIRV::OpLifetimeStop:
449         if (MI.getOperand(1).getImm() > 0)
450           validateLifetimeStart(STI, MRI, GR, MI);
451         break;
452       case SPIRV::OpGroupAsyncCopy:
453         validatePtrUnwrapStructField(STI, MRI, GR, MI, 3);
454         validatePtrUnwrapStructField(STI, MRI, GR, MI, 4);
455         break;
456       case SPIRV::OpGroupWaitEvents:
457         // OpGroupWaitEvents ..., ..., <pointer to OpTypeEvent>
458         validateGroupWaitEventsPtr(STI, MRI, GR, MI);
459         break;
460       case SPIRV::OpConstantI: {
461         SPIRVType *Type = GR.getSPIRVTypeForVReg(MI.getOperand(1).getReg());
462         if (Type->getOpcode() != SPIRV::OpTypeInt && MI.getOperand(2).isImm() &&
463             MI.getOperand(2).getImm() == 0) {
464           // Validate the null constant of a target extension type
465           MI.setDesc(STI.getInstrInfo()->get(SPIRV::OpConstantNull));
466           for (unsigned i = MI.getNumOperands() - 1; i > 1; --i)
467             MI.removeOperand(i);
468         }
469       } break;
470       case SPIRV::OpPhi: {
471         // Phi refers to a type definition that goes after the Phi
472         // instruction, so that the virtual register definition of the type
473         // doesn't dominate all uses. Let's place the type definition
474         // instruction at the end of the predecessor.
475         MachineBasicBlock *Curr = MI.getParent();
476         SPIRVType *Type = GR.getSPIRVTypeForVReg(MI.getOperand(1).getReg());
477         if (Type->getParent() == Curr && !Curr->pred_empty())
478           ToMove.insert(const_cast<MachineInstr *>(Type));
479       } break;
480       case SPIRV::OpExtInst: {
481         // prefetch
482         if (!MI.getOperand(2).isImm() || !MI.getOperand(3).isImm() ||
483             MI.getOperand(2).getImm() != SPIRV::InstructionSet::OpenCL_std)
484           continue;
485         switch (MI.getOperand(3).getImm()) {
486         case SPIRV::OpenCLExtInst::frexp:
487         case SPIRV::OpenCLExtInst::lgamma_r:
488         case SPIRV::OpenCLExtInst::remquo: {
489           // The last operand must be of a pointer to i32 or vector of i32
490           // values.
491           MachineIRBuilder MIB(MI);
492           SPIRVType *Int32Type = GR.getOrCreateSPIRVIntegerType(32, MIB);
493           SPIRVType *RetType = MRI->getVRegDef(MI.getOperand(1).getReg());
494           assert(RetType && "Expected return type");
495           validatePtrTypes(
496               STI, MRI, GR, MI, MI.getNumOperands() - 1,
497               RetType->getOpcode() != SPIRV::OpTypeVector
498                   ? Int32Type
499                   : GR.getOrCreateSPIRVVectorType(
500                         Int32Type, RetType->getOperand(2).getImm(), MIB));
501         } break;
502         case SPIRV::OpenCLExtInst::fract:
503         case SPIRV::OpenCLExtInst::modf:
504         case SPIRV::OpenCLExtInst::sincos:
505           // The last operand must be of a pointer to the base type represented
506           // by the previous operand.
507           assert(MI.getOperand(MI.getNumOperands() - 2).isReg() &&
508                  "Expected v-reg");
509           validatePtrTypes(
510               STI, MRI, GR, MI, MI.getNumOperands() - 1,
511               GR.getSPIRVTypeForVReg(
512                   MI.getOperand(MI.getNumOperands() - 2).getReg()));
513           break;
514         case SPIRV::OpenCLExtInst::prefetch:
515           // Expected `ptr` type is a pointer to float, integer or vector, but
516           // the pontee value can be wrapped into a struct.
517           assert(MI.getOperand(MI.getNumOperands() - 2).isReg() &&
518                  "Expected v-reg");
519           validatePtrUnwrapStructField(STI, MRI, GR, MI,
520                                        MI.getNumOperands() - 2);
521           break;
522         }
523       } break;
524       }
525     }
526     for (MachineInstr *MI : ToMove) {
527       MachineBasicBlock *Curr = MI->getParent();
528       MachineBasicBlock *Pred = *Curr->pred_begin();
529       Pred->insert(Pred->getFirstTerminator(), Curr->remove_instr(MI));
530     }
531   }
532   ProcessedMF.insert(&MF);
533   TargetLowering::finalizeLowering(MF);
534 }
535