xref: /freebsd-src/contrib/llvm-project/llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp (revision 0fca6ea1d4eea4c934cfff25ac9ee8ad6fe95583)
//===- SPIRVISelLowering.cpp - SPIR-V DAG Lowering Impl ---------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements the SPIRVTargetLowering class.
//
//===----------------------------------------------------------------------===//
1281ad6265SDimitry Andric 
1381ad6265SDimitry Andric #include "SPIRVISelLowering.h"
1481ad6265SDimitry Andric #include "SPIRV.h"
15*0fca6ea1SDimitry Andric #include "SPIRVInstrInfo.h"
16*0fca6ea1SDimitry Andric #include "SPIRVRegisterBankInfo.h"
17*0fca6ea1SDimitry Andric #include "SPIRVRegisterInfo.h"
18*0fca6ea1SDimitry Andric #include "SPIRVSubtarget.h"
19*0fca6ea1SDimitry Andric #include "SPIRVTargetMachine.h"
20*0fca6ea1SDimitry Andric #include "llvm/CodeGen/MachineInstrBuilder.h"
21*0fca6ea1SDimitry Andric #include "llvm/CodeGen/MachineRegisterInfo.h"
22bdd1243dSDimitry Andric #include "llvm/IR/IntrinsicsSPIRV.h"
2381ad6265SDimitry Andric 
2481ad6265SDimitry Andric #define DEBUG_TYPE "spirv-lower"
2581ad6265SDimitry Andric 
2681ad6265SDimitry Andric using namespace llvm;
2781ad6265SDimitry Andric 
2881ad6265SDimitry Andric unsigned SPIRVTargetLowering::getNumRegistersForCallingConv(
2981ad6265SDimitry Andric     LLVMContext &Context, CallingConv::ID CC, EVT VT) const {
3081ad6265SDimitry Andric   // This code avoids CallLowering fail inside getVectorTypeBreakdown
3181ad6265SDimitry Andric   // on v3i1 arguments. Maybe we need to return 1 for all types.
3281ad6265SDimitry Andric   // TODO: remove it once this case is supported by the default implementation.
3381ad6265SDimitry Andric   if (VT.isVector() && VT.getVectorNumElements() == 3 &&
3481ad6265SDimitry Andric       (VT.getVectorElementType() == MVT::i1 ||
3581ad6265SDimitry Andric        VT.getVectorElementType() == MVT::i8))
3681ad6265SDimitry Andric     return 1;
3706c3fb27SDimitry Andric   if (!VT.isVector() && VT.isInteger() && VT.getSizeInBits() <= 64)
3806c3fb27SDimitry Andric     return 1;
3981ad6265SDimitry Andric   return getNumRegisters(Context, VT);
4081ad6265SDimitry Andric }
4181ad6265SDimitry Andric 
4281ad6265SDimitry Andric MVT SPIRVTargetLowering::getRegisterTypeForCallingConv(LLVMContext &Context,
4381ad6265SDimitry Andric                                                        CallingConv::ID CC,
4481ad6265SDimitry Andric                                                        EVT VT) const {
4581ad6265SDimitry Andric   // This code avoids CallLowering fail inside getVectorTypeBreakdown
4681ad6265SDimitry Andric   // on v3i1 arguments. Maybe we need to return i32 for all types.
4781ad6265SDimitry Andric   // TODO: remove it once this case is supported by the default implementation.
4881ad6265SDimitry Andric   if (VT.isVector() && VT.getVectorNumElements() == 3) {
4981ad6265SDimitry Andric     if (VT.getVectorElementType() == MVT::i1)
5081ad6265SDimitry Andric       return MVT::v4i1;
5181ad6265SDimitry Andric     else if (VT.getVectorElementType() == MVT::i8)
5281ad6265SDimitry Andric       return MVT::v4i8;
5381ad6265SDimitry Andric   }
5481ad6265SDimitry Andric   return getRegisterType(Context, VT);
5581ad6265SDimitry Andric }
56bdd1243dSDimitry Andric 
57bdd1243dSDimitry Andric bool SPIRVTargetLowering::getTgtMemIntrinsic(IntrinsicInfo &Info,
58bdd1243dSDimitry Andric                                              const CallInst &I,
59bdd1243dSDimitry Andric                                              MachineFunction &MF,
60bdd1243dSDimitry Andric                                              unsigned Intrinsic) const {
61bdd1243dSDimitry Andric   unsigned AlignIdx = 3;
62bdd1243dSDimitry Andric   switch (Intrinsic) {
63bdd1243dSDimitry Andric   case Intrinsic::spv_load:
64bdd1243dSDimitry Andric     AlignIdx = 2;
6506c3fb27SDimitry Andric     [[fallthrough]];
66bdd1243dSDimitry Andric   case Intrinsic::spv_store: {
67bdd1243dSDimitry Andric     if (I.getNumOperands() >= AlignIdx + 1) {
68bdd1243dSDimitry Andric       auto *AlignOp = cast<ConstantInt>(I.getOperand(AlignIdx));
69bdd1243dSDimitry Andric       Info.align = Align(AlignOp->getZExtValue());
70bdd1243dSDimitry Andric     }
71bdd1243dSDimitry Andric     Info.flags = static_cast<MachineMemOperand::Flags>(
72bdd1243dSDimitry Andric         cast<ConstantInt>(I.getOperand(AlignIdx - 1))->getZExtValue());
73bdd1243dSDimitry Andric     Info.memVT = MVT::i64;
74bdd1243dSDimitry Andric     // TODO: take into account opaque pointers (don't use getElementType).
75bdd1243dSDimitry Andric     // MVT::getVT(PtrTy->getElementType());
76bdd1243dSDimitry Andric     return true;
77bdd1243dSDimitry Andric     break;
78bdd1243dSDimitry Andric   }
79bdd1243dSDimitry Andric   default:
80bdd1243dSDimitry Andric     break;
81bdd1243dSDimitry Andric   }
82bdd1243dSDimitry Andric   return false;
83bdd1243dSDimitry Andric }
84*0fca6ea1SDimitry Andric 
85*0fca6ea1SDimitry Andric std::pair<unsigned, const TargetRegisterClass *>
86*0fca6ea1SDimitry Andric SPIRVTargetLowering::getRegForInlineAsmConstraint(const TargetRegisterInfo *TRI,
87*0fca6ea1SDimitry Andric                                                   StringRef Constraint,
88*0fca6ea1SDimitry Andric                                                   MVT VT) const {
89*0fca6ea1SDimitry Andric   const TargetRegisterClass *RC = nullptr;
90*0fca6ea1SDimitry Andric   if (Constraint.starts_with("{"))
91*0fca6ea1SDimitry Andric     return std::make_pair(0u, RC);
92*0fca6ea1SDimitry Andric 
93*0fca6ea1SDimitry Andric   if (VT.isFloatingPoint())
94*0fca6ea1SDimitry Andric     RC = VT.isVector() ? &SPIRV::vfIDRegClass
95*0fca6ea1SDimitry Andric                        : (VT.getScalarSizeInBits() > 32 ? &SPIRV::fID64RegClass
96*0fca6ea1SDimitry Andric                                                         : &SPIRV::fIDRegClass);
97*0fca6ea1SDimitry Andric   else if (VT.isInteger())
98*0fca6ea1SDimitry Andric     RC = VT.isVector() ? &SPIRV::vIDRegClass
99*0fca6ea1SDimitry Andric                        : (VT.getScalarSizeInBits() > 32 ? &SPIRV::ID64RegClass
100*0fca6ea1SDimitry Andric                                                         : &SPIRV::IDRegClass);
101*0fca6ea1SDimitry Andric   else
102*0fca6ea1SDimitry Andric     RC = &SPIRV::IDRegClass;
103*0fca6ea1SDimitry Andric 
104*0fca6ea1SDimitry Andric   return std::make_pair(0u, RC);
105*0fca6ea1SDimitry Andric }
106*0fca6ea1SDimitry Andric 
107*0fca6ea1SDimitry Andric inline Register getTypeReg(MachineRegisterInfo *MRI, Register OpReg) {
108*0fca6ea1SDimitry Andric   SPIRVType *TypeInst = MRI->getVRegDef(OpReg);
109*0fca6ea1SDimitry Andric   return TypeInst && TypeInst->getOpcode() == SPIRV::OpFunctionParameter
110*0fca6ea1SDimitry Andric              ? TypeInst->getOperand(1).getReg()
111*0fca6ea1SDimitry Andric              : OpReg;
112*0fca6ea1SDimitry Andric }
113*0fca6ea1SDimitry Andric 
114*0fca6ea1SDimitry Andric static void doInsertBitcast(const SPIRVSubtarget &STI, MachineRegisterInfo *MRI,
115*0fca6ea1SDimitry Andric                             SPIRVGlobalRegistry &GR, MachineInstr &I,
116*0fca6ea1SDimitry Andric                             Register OpReg, unsigned OpIdx,
117*0fca6ea1SDimitry Andric                             SPIRVType *NewPtrType) {
118*0fca6ea1SDimitry Andric   Register NewReg = MRI->createGenericVirtualRegister(LLT::scalar(32));
119*0fca6ea1SDimitry Andric   MachineIRBuilder MIB(I);
120*0fca6ea1SDimitry Andric   bool Res = MIB.buildInstr(SPIRV::OpBitcast)
121*0fca6ea1SDimitry Andric                  .addDef(NewReg)
122*0fca6ea1SDimitry Andric                  .addUse(GR.getSPIRVTypeID(NewPtrType))
123*0fca6ea1SDimitry Andric                  .addUse(OpReg)
124*0fca6ea1SDimitry Andric                  .constrainAllUses(*STI.getInstrInfo(), *STI.getRegisterInfo(),
125*0fca6ea1SDimitry Andric                                    *STI.getRegBankInfo());
126*0fca6ea1SDimitry Andric   if (!Res)
127*0fca6ea1SDimitry Andric     report_fatal_error("insert validation bitcast: cannot constrain all uses");
128*0fca6ea1SDimitry Andric   MRI->setRegClass(NewReg, &SPIRV::IDRegClass);
129*0fca6ea1SDimitry Andric   GR.assignSPIRVTypeToVReg(NewPtrType, NewReg, MIB.getMF());
130*0fca6ea1SDimitry Andric   I.getOperand(OpIdx).setReg(NewReg);
131*0fca6ea1SDimitry Andric }
132*0fca6ea1SDimitry Andric 
133*0fca6ea1SDimitry Andric static SPIRVType *createNewPtrType(SPIRVGlobalRegistry &GR, MachineInstr &I,
134*0fca6ea1SDimitry Andric                                    SPIRVType *OpType, bool ReuseType,
135*0fca6ea1SDimitry Andric                                    bool EmitIR, SPIRVType *ResType,
136*0fca6ea1SDimitry Andric                                    const Type *ResTy) {
137*0fca6ea1SDimitry Andric   SPIRV::StorageClass::StorageClass SC =
138*0fca6ea1SDimitry Andric       static_cast<SPIRV::StorageClass::StorageClass>(
139*0fca6ea1SDimitry Andric           OpType->getOperand(1).getImm());
140*0fca6ea1SDimitry Andric   MachineIRBuilder MIB(I);
141*0fca6ea1SDimitry Andric   SPIRVType *NewBaseType =
142*0fca6ea1SDimitry Andric       ReuseType ? ResType
143*0fca6ea1SDimitry Andric                 : GR.getOrCreateSPIRVType(
144*0fca6ea1SDimitry Andric                       ResTy, MIB, SPIRV::AccessQualifier::ReadWrite, EmitIR);
145*0fca6ea1SDimitry Andric   return GR.getOrCreateSPIRVPointerType(NewBaseType, MIB, SC);
146*0fca6ea1SDimitry Andric }
147*0fca6ea1SDimitry Andric 
148*0fca6ea1SDimitry Andric // Insert a bitcast before the instruction to keep SPIR-V code valid
149*0fca6ea1SDimitry Andric // when there is a type mismatch between results and operand types.
150*0fca6ea1SDimitry Andric static void validatePtrTypes(const SPIRVSubtarget &STI,
151*0fca6ea1SDimitry Andric                              MachineRegisterInfo *MRI, SPIRVGlobalRegistry &GR,
152*0fca6ea1SDimitry Andric                              MachineInstr &I, unsigned OpIdx,
153*0fca6ea1SDimitry Andric                              SPIRVType *ResType, const Type *ResTy = nullptr) {
154*0fca6ea1SDimitry Andric   // Get operand type
155*0fca6ea1SDimitry Andric   MachineFunction *MF = I.getParent()->getParent();
156*0fca6ea1SDimitry Andric   Register OpReg = I.getOperand(OpIdx).getReg();
157*0fca6ea1SDimitry Andric   Register OpTypeReg = getTypeReg(MRI, OpReg);
158*0fca6ea1SDimitry Andric   SPIRVType *OpType = GR.getSPIRVTypeForVReg(OpTypeReg, MF);
159*0fca6ea1SDimitry Andric   if (!ResType || !OpType || OpType->getOpcode() != SPIRV::OpTypePointer)
160*0fca6ea1SDimitry Andric     return;
161*0fca6ea1SDimitry Andric   // Get operand's pointee type
162*0fca6ea1SDimitry Andric   Register ElemTypeReg = OpType->getOperand(2).getReg();
163*0fca6ea1SDimitry Andric   SPIRVType *ElemType = GR.getSPIRVTypeForVReg(ElemTypeReg, MF);
164*0fca6ea1SDimitry Andric   if (!ElemType)
165*0fca6ea1SDimitry Andric     return;
166*0fca6ea1SDimitry Andric   // Check if we need a bitcast to make a statement valid
167*0fca6ea1SDimitry Andric   bool IsSameMF = MF == ResType->getParent()->getParent();
168*0fca6ea1SDimitry Andric   bool IsEqualTypes = IsSameMF ? ElemType == ResType
169*0fca6ea1SDimitry Andric                                : GR.getTypeForSPIRVType(ElemType) == ResTy;
170*0fca6ea1SDimitry Andric   if (IsEqualTypes)
171*0fca6ea1SDimitry Andric     return;
172*0fca6ea1SDimitry Andric   // There is a type mismatch between results and operand types
173*0fca6ea1SDimitry Andric   // and we insert a bitcast before the instruction to keep SPIR-V code valid
174*0fca6ea1SDimitry Andric   SPIRVType *NewPtrType =
175*0fca6ea1SDimitry Andric       createNewPtrType(GR, I, OpType, IsSameMF, false, ResType, ResTy);
176*0fca6ea1SDimitry Andric   if (!GR.isBitcastCompatible(NewPtrType, OpType))
177*0fca6ea1SDimitry Andric     report_fatal_error(
178*0fca6ea1SDimitry Andric         "insert validation bitcast: incompatible result and operand types");
179*0fca6ea1SDimitry Andric   doInsertBitcast(STI, MRI, GR, I, OpReg, OpIdx, NewPtrType);
180*0fca6ea1SDimitry Andric }
181*0fca6ea1SDimitry Andric 
182*0fca6ea1SDimitry Andric // Insert a bitcast before OpGroupWaitEvents if the last argument is a pointer
183*0fca6ea1SDimitry Andric // that doesn't point to OpTypeEvent.
184*0fca6ea1SDimitry Andric static void validateGroupWaitEventsPtr(const SPIRVSubtarget &STI,
185*0fca6ea1SDimitry Andric                                        MachineRegisterInfo *MRI,
186*0fca6ea1SDimitry Andric                                        SPIRVGlobalRegistry &GR,
187*0fca6ea1SDimitry Andric                                        MachineInstr &I) {
188*0fca6ea1SDimitry Andric   constexpr unsigned OpIdx = 2;
189*0fca6ea1SDimitry Andric   MachineFunction *MF = I.getParent()->getParent();
190*0fca6ea1SDimitry Andric   Register OpReg = I.getOperand(OpIdx).getReg();
191*0fca6ea1SDimitry Andric   Register OpTypeReg = getTypeReg(MRI, OpReg);
192*0fca6ea1SDimitry Andric   SPIRVType *OpType = GR.getSPIRVTypeForVReg(OpTypeReg, MF);
193*0fca6ea1SDimitry Andric   if (!OpType || OpType->getOpcode() != SPIRV::OpTypePointer)
194*0fca6ea1SDimitry Andric     return;
195*0fca6ea1SDimitry Andric   SPIRVType *ElemType = GR.getSPIRVTypeForVReg(OpType->getOperand(2).getReg());
196*0fca6ea1SDimitry Andric   if (!ElemType || ElemType->getOpcode() == SPIRV::OpTypeEvent)
197*0fca6ea1SDimitry Andric     return;
198*0fca6ea1SDimitry Andric   // Insert a bitcast before the instruction to keep SPIR-V code valid.
199*0fca6ea1SDimitry Andric   LLVMContext &Context = MF->getFunction().getContext();
200*0fca6ea1SDimitry Andric   SPIRVType *NewPtrType =
201*0fca6ea1SDimitry Andric       createNewPtrType(GR, I, OpType, false, true, nullptr,
202*0fca6ea1SDimitry Andric                        TargetExtType::get(Context, "spirv.Event"));
203*0fca6ea1SDimitry Andric   doInsertBitcast(STI, MRI, GR, I, OpReg, OpIdx, NewPtrType);
204*0fca6ea1SDimitry Andric }
205*0fca6ea1SDimitry Andric 
206*0fca6ea1SDimitry Andric static void validateGroupAsyncCopyPtr(const SPIRVSubtarget &STI,
207*0fca6ea1SDimitry Andric                                       MachineRegisterInfo *MRI,
208*0fca6ea1SDimitry Andric                                       SPIRVGlobalRegistry &GR, MachineInstr &I,
209*0fca6ea1SDimitry Andric                                       unsigned OpIdx) {
210*0fca6ea1SDimitry Andric   MachineFunction *MF = I.getParent()->getParent();
211*0fca6ea1SDimitry Andric   Register OpReg = I.getOperand(OpIdx).getReg();
212*0fca6ea1SDimitry Andric   Register OpTypeReg = getTypeReg(MRI, OpReg);
213*0fca6ea1SDimitry Andric   SPIRVType *OpType = GR.getSPIRVTypeForVReg(OpTypeReg, MF);
214*0fca6ea1SDimitry Andric   if (!OpType || OpType->getOpcode() != SPIRV::OpTypePointer)
215*0fca6ea1SDimitry Andric     return;
216*0fca6ea1SDimitry Andric   SPIRVType *ElemType = GR.getSPIRVTypeForVReg(OpType->getOperand(2).getReg());
217*0fca6ea1SDimitry Andric   if (!ElemType || ElemType->getOpcode() != SPIRV::OpTypeStruct ||
218*0fca6ea1SDimitry Andric       ElemType->getNumOperands() != 2)
219*0fca6ea1SDimitry Andric     return;
220*0fca6ea1SDimitry Andric   // It's a structure-wrapper around another type with a single member field.
221*0fca6ea1SDimitry Andric   SPIRVType *MemberType =
222*0fca6ea1SDimitry Andric       GR.getSPIRVTypeForVReg(ElemType->getOperand(1).getReg());
223*0fca6ea1SDimitry Andric   if (!MemberType)
224*0fca6ea1SDimitry Andric     return;
225*0fca6ea1SDimitry Andric   unsigned MemberTypeOp = MemberType->getOpcode();
226*0fca6ea1SDimitry Andric   if (MemberTypeOp != SPIRV::OpTypeVector && MemberTypeOp != SPIRV::OpTypeInt &&
227*0fca6ea1SDimitry Andric       MemberTypeOp != SPIRV::OpTypeFloat && MemberTypeOp != SPIRV::OpTypeBool)
228*0fca6ea1SDimitry Andric     return;
229*0fca6ea1SDimitry Andric   // It's a structure-wrapper around a valid type. Insert a bitcast before the
230*0fca6ea1SDimitry Andric   // instruction to keep SPIR-V code valid.
231*0fca6ea1SDimitry Andric   SPIRV::StorageClass::StorageClass SC =
232*0fca6ea1SDimitry Andric       static_cast<SPIRV::StorageClass::StorageClass>(
233*0fca6ea1SDimitry Andric           OpType->getOperand(1).getImm());
234*0fca6ea1SDimitry Andric   MachineIRBuilder MIB(I);
235*0fca6ea1SDimitry Andric   SPIRVType *NewPtrType = GR.getOrCreateSPIRVPointerType(MemberType, MIB, SC);
236*0fca6ea1SDimitry Andric   doInsertBitcast(STI, MRI, GR, I, OpReg, OpIdx, NewPtrType);
237*0fca6ea1SDimitry Andric }
238*0fca6ea1SDimitry Andric 
239*0fca6ea1SDimitry Andric // Insert a bitcast before the function call instruction to keep SPIR-V code
240*0fca6ea1SDimitry Andric // valid when there is a type mismatch between actual and expected types of an
241*0fca6ea1SDimitry Andric // argument:
242*0fca6ea1SDimitry Andric // %formal = OpFunctionParameter %formal_type
243*0fca6ea1SDimitry Andric // ...
244*0fca6ea1SDimitry Andric // %res = OpFunctionCall %ty %fun %actual ...
245*0fca6ea1SDimitry Andric // implies that %actual is of %formal_type, and in case of opaque pointers.
246*0fca6ea1SDimitry Andric // We may need to insert a bitcast to ensure this.
247*0fca6ea1SDimitry Andric void validateFunCallMachineDef(const SPIRVSubtarget &STI,
248*0fca6ea1SDimitry Andric                                MachineRegisterInfo *DefMRI,
249*0fca6ea1SDimitry Andric                                MachineRegisterInfo *CallMRI,
250*0fca6ea1SDimitry Andric                                SPIRVGlobalRegistry &GR, MachineInstr &FunCall,
251*0fca6ea1SDimitry Andric                                MachineInstr *FunDef) {
252*0fca6ea1SDimitry Andric   if (FunDef->getOpcode() != SPIRV::OpFunction)
253*0fca6ea1SDimitry Andric     return;
254*0fca6ea1SDimitry Andric   unsigned OpIdx = 3;
255*0fca6ea1SDimitry Andric   for (FunDef = FunDef->getNextNode();
256*0fca6ea1SDimitry Andric        FunDef && FunDef->getOpcode() == SPIRV::OpFunctionParameter &&
257*0fca6ea1SDimitry Andric        OpIdx < FunCall.getNumOperands();
258*0fca6ea1SDimitry Andric        FunDef = FunDef->getNextNode(), OpIdx++) {
259*0fca6ea1SDimitry Andric     SPIRVType *DefPtrType = DefMRI->getVRegDef(FunDef->getOperand(1).getReg());
260*0fca6ea1SDimitry Andric     SPIRVType *DefElemType =
261*0fca6ea1SDimitry Andric         DefPtrType && DefPtrType->getOpcode() == SPIRV::OpTypePointer
262*0fca6ea1SDimitry Andric             ? GR.getSPIRVTypeForVReg(DefPtrType->getOperand(2).getReg(),
263*0fca6ea1SDimitry Andric                                      DefPtrType->getParent()->getParent())
264*0fca6ea1SDimitry Andric             : nullptr;
265*0fca6ea1SDimitry Andric     if (DefElemType) {
266*0fca6ea1SDimitry Andric       const Type *DefElemTy = GR.getTypeForSPIRVType(DefElemType);
267*0fca6ea1SDimitry Andric       // validatePtrTypes() works in the context if the call site
268*0fca6ea1SDimitry Andric       // When we process historical records about forward calls
269*0fca6ea1SDimitry Andric       // we need to switch context to the (forward) call site and
270*0fca6ea1SDimitry Andric       // then restore it back to the current machine function.
271*0fca6ea1SDimitry Andric       MachineFunction *CurMF =
272*0fca6ea1SDimitry Andric           GR.setCurrentFunc(*FunCall.getParent()->getParent());
273*0fca6ea1SDimitry Andric       validatePtrTypes(STI, CallMRI, GR, FunCall, OpIdx, DefElemType,
274*0fca6ea1SDimitry Andric                        DefElemTy);
275*0fca6ea1SDimitry Andric       GR.setCurrentFunc(*CurMF);
276*0fca6ea1SDimitry Andric     }
277*0fca6ea1SDimitry Andric   }
278*0fca6ea1SDimitry Andric }
279*0fca6ea1SDimitry Andric 
280*0fca6ea1SDimitry Andric // Ensure there is no mismatch between actual and expected arg types: calls
281*0fca6ea1SDimitry Andric // with a processed definition. Return Function pointer if it's a forward
282*0fca6ea1SDimitry Andric // call (ahead of definition), and nullptr otherwise.
283*0fca6ea1SDimitry Andric const Function *validateFunCall(const SPIRVSubtarget &STI,
284*0fca6ea1SDimitry Andric                                 MachineRegisterInfo *CallMRI,
285*0fca6ea1SDimitry Andric                                 SPIRVGlobalRegistry &GR,
286*0fca6ea1SDimitry Andric                                 MachineInstr &FunCall) {
287*0fca6ea1SDimitry Andric   const GlobalValue *GV = FunCall.getOperand(2).getGlobal();
288*0fca6ea1SDimitry Andric   const Function *F = dyn_cast<Function>(GV);
289*0fca6ea1SDimitry Andric   MachineInstr *FunDef =
290*0fca6ea1SDimitry Andric       const_cast<MachineInstr *>(GR.getFunctionDefinition(F));
291*0fca6ea1SDimitry Andric   if (!FunDef)
292*0fca6ea1SDimitry Andric     return F;
293*0fca6ea1SDimitry Andric   MachineRegisterInfo *DefMRI = &FunDef->getParent()->getParent()->getRegInfo();
294*0fca6ea1SDimitry Andric   validateFunCallMachineDef(STI, DefMRI, CallMRI, GR, FunCall, FunDef);
295*0fca6ea1SDimitry Andric   return nullptr;
296*0fca6ea1SDimitry Andric }
297*0fca6ea1SDimitry Andric 
298*0fca6ea1SDimitry Andric // Ensure there is no mismatch between actual and expected arg types: calls
299*0fca6ea1SDimitry Andric // ahead of a processed definition.
300*0fca6ea1SDimitry Andric void validateForwardCalls(const SPIRVSubtarget &STI,
301*0fca6ea1SDimitry Andric                           MachineRegisterInfo *DefMRI, SPIRVGlobalRegistry &GR,
302*0fca6ea1SDimitry Andric                           MachineInstr &FunDef) {
303*0fca6ea1SDimitry Andric   const Function *F = GR.getFunctionByDefinition(&FunDef);
304*0fca6ea1SDimitry Andric   if (SmallPtrSet<MachineInstr *, 8> *FwdCalls = GR.getForwardCalls(F))
305*0fca6ea1SDimitry Andric     for (MachineInstr *FunCall : *FwdCalls) {
306*0fca6ea1SDimitry Andric       MachineRegisterInfo *CallMRI =
307*0fca6ea1SDimitry Andric           &FunCall->getParent()->getParent()->getRegInfo();
308*0fca6ea1SDimitry Andric       validateFunCallMachineDef(STI, DefMRI, CallMRI, GR, *FunCall, &FunDef);
309*0fca6ea1SDimitry Andric     }
310*0fca6ea1SDimitry Andric }
311*0fca6ea1SDimitry Andric 
312*0fca6ea1SDimitry Andric // Validation of an access chain.
313*0fca6ea1SDimitry Andric void validateAccessChain(const SPIRVSubtarget &STI, MachineRegisterInfo *MRI,
314*0fca6ea1SDimitry Andric                          SPIRVGlobalRegistry &GR, MachineInstr &I) {
315*0fca6ea1SDimitry Andric   SPIRVType *BaseTypeInst = GR.getSPIRVTypeForVReg(I.getOperand(0).getReg());
316*0fca6ea1SDimitry Andric   if (BaseTypeInst && BaseTypeInst->getOpcode() == SPIRV::OpTypePointer) {
317*0fca6ea1SDimitry Andric     SPIRVType *BaseElemType =
318*0fca6ea1SDimitry Andric         GR.getSPIRVTypeForVReg(BaseTypeInst->getOperand(2).getReg());
319*0fca6ea1SDimitry Andric     validatePtrTypes(STI, MRI, GR, I, 2, BaseElemType);
320*0fca6ea1SDimitry Andric   }
321*0fca6ea1SDimitry Andric }
322*0fca6ea1SDimitry Andric 
323*0fca6ea1SDimitry Andric // TODO: the logic of inserting additional bitcast's is to be moved
324*0fca6ea1SDimitry Andric // to pre-IRTranslation passes eventually
325*0fca6ea1SDimitry Andric void SPIRVTargetLowering::finalizeLowering(MachineFunction &MF) const {
326*0fca6ea1SDimitry Andric   // finalizeLowering() is called twice (see GlobalISel/InstructionSelect.cpp)
327*0fca6ea1SDimitry Andric   // We'd like to avoid the needless second processing pass.
328*0fca6ea1SDimitry Andric   if (ProcessedMF.find(&MF) != ProcessedMF.end())
329*0fca6ea1SDimitry Andric     return;
330*0fca6ea1SDimitry Andric 
331*0fca6ea1SDimitry Andric   MachineRegisterInfo *MRI = &MF.getRegInfo();
332*0fca6ea1SDimitry Andric   SPIRVGlobalRegistry &GR = *STI.getSPIRVGlobalRegistry();
333*0fca6ea1SDimitry Andric   GR.setCurrentFunc(MF);
334*0fca6ea1SDimitry Andric   for (MachineFunction::iterator I = MF.begin(), E = MF.end(); I != E; ++I) {
335*0fca6ea1SDimitry Andric     MachineBasicBlock *MBB = &*I;
336*0fca6ea1SDimitry Andric     for (MachineBasicBlock::iterator MBBI = MBB->begin(), MBBE = MBB->end();
337*0fca6ea1SDimitry Andric          MBBI != MBBE;) {
338*0fca6ea1SDimitry Andric       MachineInstr &MI = *MBBI++;
339*0fca6ea1SDimitry Andric       switch (MI.getOpcode()) {
340*0fca6ea1SDimitry Andric       case SPIRV::OpAtomicLoad:
341*0fca6ea1SDimitry Andric       case SPIRV::OpAtomicExchange:
342*0fca6ea1SDimitry Andric       case SPIRV::OpAtomicCompareExchange:
343*0fca6ea1SDimitry Andric       case SPIRV::OpAtomicCompareExchangeWeak:
344*0fca6ea1SDimitry Andric       case SPIRV::OpAtomicIIncrement:
345*0fca6ea1SDimitry Andric       case SPIRV::OpAtomicIDecrement:
346*0fca6ea1SDimitry Andric       case SPIRV::OpAtomicIAdd:
347*0fca6ea1SDimitry Andric       case SPIRV::OpAtomicISub:
348*0fca6ea1SDimitry Andric       case SPIRV::OpAtomicSMin:
349*0fca6ea1SDimitry Andric       case SPIRV::OpAtomicUMin:
350*0fca6ea1SDimitry Andric       case SPIRV::OpAtomicSMax:
351*0fca6ea1SDimitry Andric       case SPIRV::OpAtomicUMax:
352*0fca6ea1SDimitry Andric       case SPIRV::OpAtomicAnd:
353*0fca6ea1SDimitry Andric       case SPIRV::OpAtomicOr:
354*0fca6ea1SDimitry Andric       case SPIRV::OpAtomicXor:
355*0fca6ea1SDimitry Andric         // for the above listed instructions
356*0fca6ea1SDimitry Andric         // OpAtomicXXX <ResType>, ptr %Op, ...
357*0fca6ea1SDimitry Andric         // implies that %Op is a pointer to <ResType>
358*0fca6ea1SDimitry Andric       case SPIRV::OpLoad:
359*0fca6ea1SDimitry Andric         // OpLoad <ResType>, ptr %Op implies that %Op is a pointer to <ResType>
360*0fca6ea1SDimitry Andric         validatePtrTypes(STI, MRI, GR, MI, 2,
361*0fca6ea1SDimitry Andric                          GR.getSPIRVTypeForVReg(MI.getOperand(0).getReg()));
362*0fca6ea1SDimitry Andric         break;
363*0fca6ea1SDimitry Andric       case SPIRV::OpAtomicStore:
364*0fca6ea1SDimitry Andric         // OpAtomicStore ptr %Op, <Scope>, <Mem>, <Obj>
365*0fca6ea1SDimitry Andric         // implies that %Op points to the <Obj>'s type
366*0fca6ea1SDimitry Andric         validatePtrTypes(STI, MRI, GR, MI, 0,
367*0fca6ea1SDimitry Andric                          GR.getSPIRVTypeForVReg(MI.getOperand(3).getReg()));
368*0fca6ea1SDimitry Andric         break;
369*0fca6ea1SDimitry Andric       case SPIRV::OpStore:
370*0fca6ea1SDimitry Andric         // OpStore ptr %Op, <Obj> implies that %Op points to the <Obj>'s type
371*0fca6ea1SDimitry Andric         validatePtrTypes(STI, MRI, GR, MI, 0,
372*0fca6ea1SDimitry Andric                          GR.getSPIRVTypeForVReg(MI.getOperand(1).getReg()));
373*0fca6ea1SDimitry Andric         break;
374*0fca6ea1SDimitry Andric       case SPIRV::OpPtrCastToGeneric:
375*0fca6ea1SDimitry Andric       case SPIRV::OpGenericCastToPtr:
376*0fca6ea1SDimitry Andric         validateAccessChain(STI, MRI, GR, MI);
377*0fca6ea1SDimitry Andric         break;
378*0fca6ea1SDimitry Andric       case SPIRV::OpInBoundsPtrAccessChain:
379*0fca6ea1SDimitry Andric         if (MI.getNumOperands() == 4)
380*0fca6ea1SDimitry Andric           validateAccessChain(STI, MRI, GR, MI);
381*0fca6ea1SDimitry Andric         break;
382*0fca6ea1SDimitry Andric 
383*0fca6ea1SDimitry Andric       case SPIRV::OpFunctionCall:
384*0fca6ea1SDimitry Andric         // ensure there is no mismatch between actual and expected arg types:
385*0fca6ea1SDimitry Andric         // calls with a processed definition
386*0fca6ea1SDimitry Andric         if (MI.getNumOperands() > 3)
387*0fca6ea1SDimitry Andric           if (const Function *F = validateFunCall(STI, MRI, GR, MI))
388*0fca6ea1SDimitry Andric             GR.addForwardCall(F, &MI);
389*0fca6ea1SDimitry Andric         break;
390*0fca6ea1SDimitry Andric       case SPIRV::OpFunction:
391*0fca6ea1SDimitry Andric         // ensure there is no mismatch between actual and expected arg types:
392*0fca6ea1SDimitry Andric         // calls ahead of a processed definition
393*0fca6ea1SDimitry Andric         validateForwardCalls(STI, MRI, GR, MI);
394*0fca6ea1SDimitry Andric         break;
395*0fca6ea1SDimitry Andric 
396*0fca6ea1SDimitry Andric       // ensure that LLVM IR bitwise instructions result in logical SPIR-V
397*0fca6ea1SDimitry Andric       // instructions when applied to bool type
398*0fca6ea1SDimitry Andric       case SPIRV::OpBitwiseOrS:
399*0fca6ea1SDimitry Andric       case SPIRV::OpBitwiseOrV:
400*0fca6ea1SDimitry Andric         if (GR.isScalarOrVectorOfType(MI.getOperand(1).getReg(),
401*0fca6ea1SDimitry Andric                                       SPIRV::OpTypeBool))
402*0fca6ea1SDimitry Andric           MI.setDesc(STI.getInstrInfo()->get(SPIRV::OpLogicalOr));
403*0fca6ea1SDimitry Andric         break;
404*0fca6ea1SDimitry Andric       case SPIRV::OpBitwiseAndS:
405*0fca6ea1SDimitry Andric       case SPIRV::OpBitwiseAndV:
406*0fca6ea1SDimitry Andric         if (GR.isScalarOrVectorOfType(MI.getOperand(1).getReg(),
407*0fca6ea1SDimitry Andric                                       SPIRV::OpTypeBool))
408*0fca6ea1SDimitry Andric           MI.setDesc(STI.getInstrInfo()->get(SPIRV::OpLogicalAnd));
409*0fca6ea1SDimitry Andric         break;
410*0fca6ea1SDimitry Andric       case SPIRV::OpBitwiseXorS:
411*0fca6ea1SDimitry Andric       case SPIRV::OpBitwiseXorV:
412*0fca6ea1SDimitry Andric         if (GR.isScalarOrVectorOfType(MI.getOperand(1).getReg(),
413*0fca6ea1SDimitry Andric                                       SPIRV::OpTypeBool))
414*0fca6ea1SDimitry Andric           MI.setDesc(STI.getInstrInfo()->get(SPIRV::OpLogicalNotEqual));
415*0fca6ea1SDimitry Andric         break;
416*0fca6ea1SDimitry Andric       case SPIRV::OpGroupAsyncCopy:
417*0fca6ea1SDimitry Andric         validateGroupAsyncCopyPtr(STI, MRI, GR, MI, 3);
418*0fca6ea1SDimitry Andric         validateGroupAsyncCopyPtr(STI, MRI, GR, MI, 4);
419*0fca6ea1SDimitry Andric         break;
420*0fca6ea1SDimitry Andric       case SPIRV::OpGroupWaitEvents:
421*0fca6ea1SDimitry Andric         // OpGroupWaitEvents ..., ..., <pointer to OpTypeEvent>
422*0fca6ea1SDimitry Andric         validateGroupWaitEventsPtr(STI, MRI, GR, MI);
423*0fca6ea1SDimitry Andric         break;
424*0fca6ea1SDimitry Andric       case SPIRV::OpConstantI: {
425*0fca6ea1SDimitry Andric         SPIRVType *Type = GR.getSPIRVTypeForVReg(MI.getOperand(1).getReg());
426*0fca6ea1SDimitry Andric         if (Type->getOpcode() != SPIRV::OpTypeInt && MI.getOperand(2).isImm() &&
427*0fca6ea1SDimitry Andric             MI.getOperand(2).getImm() == 0) {
428*0fca6ea1SDimitry Andric           // Validate the null constant of a target extension type
429*0fca6ea1SDimitry Andric           MI.setDesc(STI.getInstrInfo()->get(SPIRV::OpConstantNull));
430*0fca6ea1SDimitry Andric           for (unsigned i = MI.getNumOperands() - 1; i > 1; --i)
431*0fca6ea1SDimitry Andric             MI.removeOperand(i);
432*0fca6ea1SDimitry Andric         }
433*0fca6ea1SDimitry Andric       } break;
434*0fca6ea1SDimitry Andric       }
435*0fca6ea1SDimitry Andric     }
436*0fca6ea1SDimitry Andric   }
437*0fca6ea1SDimitry Andric   ProcessedMF.insert(&MF);
438*0fca6ea1SDimitry Andric   TargetLowering::finalizeLowering(MF);
439*0fca6ea1SDimitry Andric }
440