181ad6265SDimitry Andric //===- SPIRVISelLowering.cpp - SPIR-V DAG Lowering Impl ---------*- C++ -*-===// 281ad6265SDimitry Andric // 381ad6265SDimitry Andric // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 481ad6265SDimitry Andric // See https://llvm.org/LICENSE.txt for license information. 581ad6265SDimitry Andric // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 681ad6265SDimitry Andric // 781ad6265SDimitry Andric //===----------------------------------------------------------------------===// 881ad6265SDimitry Andric // 981ad6265SDimitry Andric // This file implements the SPIRVTargetLowering class. 1081ad6265SDimitry Andric // 1181ad6265SDimitry Andric //===----------------------------------------------------------------------===// 1281ad6265SDimitry Andric 1381ad6265SDimitry Andric #include "SPIRVISelLowering.h" 1481ad6265SDimitry Andric #include "SPIRV.h" 15*0fca6ea1SDimitry Andric #include "SPIRVInstrInfo.h" 16*0fca6ea1SDimitry Andric #include "SPIRVRegisterBankInfo.h" 17*0fca6ea1SDimitry Andric #include "SPIRVRegisterInfo.h" 18*0fca6ea1SDimitry Andric #include "SPIRVSubtarget.h" 19*0fca6ea1SDimitry Andric #include "SPIRVTargetMachine.h" 20*0fca6ea1SDimitry Andric #include "llvm/CodeGen/MachineInstrBuilder.h" 21*0fca6ea1SDimitry Andric #include "llvm/CodeGen/MachineRegisterInfo.h" 22bdd1243dSDimitry Andric #include "llvm/IR/IntrinsicsSPIRV.h" 2381ad6265SDimitry Andric 2481ad6265SDimitry Andric #define DEBUG_TYPE "spirv-lower" 2581ad6265SDimitry Andric 2681ad6265SDimitry Andric using namespace llvm; 2781ad6265SDimitry Andric 2881ad6265SDimitry Andric unsigned SPIRVTargetLowering::getNumRegistersForCallingConv( 2981ad6265SDimitry Andric LLVMContext &Context, CallingConv::ID CC, EVT VT) const { 3081ad6265SDimitry Andric // This code avoids CallLowering fail inside getVectorTypeBreakdown 3181ad6265SDimitry Andric // on v3i1 arguments. Maybe we need to return 1 for all types. 3281ad6265SDimitry Andric // TODO: remove it once this case is supported by the default implementation. 3381ad6265SDimitry Andric if (VT.isVector() && VT.getVectorNumElements() == 3 && 3481ad6265SDimitry Andric (VT.getVectorElementType() == MVT::i1 || 3581ad6265SDimitry Andric VT.getVectorElementType() == MVT::i8)) 3681ad6265SDimitry Andric return 1; 3706c3fb27SDimitry Andric if (!VT.isVector() && VT.isInteger() && VT.getSizeInBits() <= 64) 3806c3fb27SDimitry Andric return 1; 3981ad6265SDimitry Andric return getNumRegisters(Context, VT); 4081ad6265SDimitry Andric } 4181ad6265SDimitry Andric 4281ad6265SDimitry Andric MVT SPIRVTargetLowering::getRegisterTypeForCallingConv(LLVMContext &Context, 4381ad6265SDimitry Andric CallingConv::ID CC, 4481ad6265SDimitry Andric EVT VT) const { 4581ad6265SDimitry Andric // This code avoids CallLowering fail inside getVectorTypeBreakdown 4681ad6265SDimitry Andric // on v3i1 arguments. Maybe we need to return i32 for all types. 4781ad6265SDimitry Andric // TODO: remove it once this case is supported by the default implementation. 4881ad6265SDimitry Andric if (VT.isVector() && VT.getVectorNumElements() == 3) { 4981ad6265SDimitry Andric if (VT.getVectorElementType() == MVT::i1) 5081ad6265SDimitry Andric return MVT::v4i1; 5181ad6265SDimitry Andric else if (VT.getVectorElementType() == MVT::i8) 5281ad6265SDimitry Andric return MVT::v4i8; 5381ad6265SDimitry Andric } 5481ad6265SDimitry Andric return getRegisterType(Context, VT); 5581ad6265SDimitry Andric } 56bdd1243dSDimitry Andric 57bdd1243dSDimitry Andric bool SPIRVTargetLowering::getTgtMemIntrinsic(IntrinsicInfo &Info, 58bdd1243dSDimitry Andric const CallInst &I, 59bdd1243dSDimitry Andric MachineFunction &MF, 60bdd1243dSDimitry Andric unsigned Intrinsic) const { 61bdd1243dSDimitry Andric unsigned AlignIdx = 3; 62bdd1243dSDimitry Andric switch (Intrinsic) { 63bdd1243dSDimitry Andric case Intrinsic::spv_load: 64bdd1243dSDimitry Andric AlignIdx = 2; 6506c3fb27SDimitry Andric [[fallthrough]]; 66bdd1243dSDimitry Andric case Intrinsic::spv_store: { 67bdd1243dSDimitry Andric if (I.getNumOperands() >= AlignIdx + 1) { 68bdd1243dSDimitry Andric auto *AlignOp = cast<ConstantInt>(I.getOperand(AlignIdx)); 69bdd1243dSDimitry Andric Info.align = Align(AlignOp->getZExtValue()); 70bdd1243dSDimitry Andric } 71bdd1243dSDimitry Andric Info.flags = static_cast<MachineMemOperand::Flags>( 72bdd1243dSDimitry Andric cast<ConstantInt>(I.getOperand(AlignIdx - 1))->getZExtValue()); 73bdd1243dSDimitry Andric Info.memVT = MVT::i64; 74bdd1243dSDimitry Andric // TODO: take into account opaque pointers (don't use getElementType). 75bdd1243dSDimitry Andric // MVT::getVT(PtrTy->getElementType()); 76bdd1243dSDimitry Andric return true; 77bdd1243dSDimitry Andric break; 78bdd1243dSDimitry Andric } 79bdd1243dSDimitry Andric default: 80bdd1243dSDimitry Andric break; 81bdd1243dSDimitry Andric } 82bdd1243dSDimitry Andric return false; 83bdd1243dSDimitry Andric } 84*0fca6ea1SDimitry Andric 85*0fca6ea1SDimitry Andric std::pair<unsigned, const TargetRegisterClass *> 86*0fca6ea1SDimitry Andric SPIRVTargetLowering::getRegForInlineAsmConstraint(const TargetRegisterInfo *TRI, 87*0fca6ea1SDimitry Andric StringRef Constraint, 88*0fca6ea1SDimitry Andric MVT VT) const { 89*0fca6ea1SDimitry Andric const TargetRegisterClass *RC = nullptr; 90*0fca6ea1SDimitry Andric if (Constraint.starts_with("{")) 91*0fca6ea1SDimitry Andric return std::make_pair(0u, RC); 92*0fca6ea1SDimitry Andric 93*0fca6ea1SDimitry Andric if (VT.isFloatingPoint()) 94*0fca6ea1SDimitry Andric RC = VT.isVector() ? &SPIRV::vfIDRegClass 95*0fca6ea1SDimitry Andric : (VT.getScalarSizeInBits() > 32 ? &SPIRV::fID64RegClass 96*0fca6ea1SDimitry Andric : &SPIRV::fIDRegClass); 97*0fca6ea1SDimitry Andric else if (VT.isInteger()) 98*0fca6ea1SDimitry Andric RC = VT.isVector() ? &SPIRV::vIDRegClass 99*0fca6ea1SDimitry Andric : (VT.getScalarSizeInBits() > 32 ? &SPIRV::ID64RegClass 100*0fca6ea1SDimitry Andric : &SPIRV::IDRegClass); 101*0fca6ea1SDimitry Andric else 102*0fca6ea1SDimitry Andric RC = &SPIRV::IDRegClass; 103*0fca6ea1SDimitry Andric 104*0fca6ea1SDimitry Andric return std::make_pair(0u, RC); 105*0fca6ea1SDimitry Andric } 106*0fca6ea1SDimitry Andric 107*0fca6ea1SDimitry Andric inline Register getTypeReg(MachineRegisterInfo *MRI, Register OpReg) { 108*0fca6ea1SDimitry Andric SPIRVType *TypeInst = MRI->getVRegDef(OpReg); 109*0fca6ea1SDimitry Andric return TypeInst && TypeInst->getOpcode() == SPIRV::OpFunctionParameter 110*0fca6ea1SDimitry Andric ? TypeInst->getOperand(1).getReg() 111*0fca6ea1SDimitry Andric : OpReg; 112*0fca6ea1SDimitry Andric } 113*0fca6ea1SDimitry Andric 114*0fca6ea1SDimitry Andric static void doInsertBitcast(const SPIRVSubtarget &STI, MachineRegisterInfo *MRI, 115*0fca6ea1SDimitry Andric SPIRVGlobalRegistry &GR, MachineInstr &I, 116*0fca6ea1SDimitry Andric Register OpReg, unsigned OpIdx, 117*0fca6ea1SDimitry Andric SPIRVType *NewPtrType) { 118*0fca6ea1SDimitry Andric Register NewReg = MRI->createGenericVirtualRegister(LLT::scalar(32)); 119*0fca6ea1SDimitry Andric MachineIRBuilder MIB(I); 120*0fca6ea1SDimitry Andric bool Res = MIB.buildInstr(SPIRV::OpBitcast) 121*0fca6ea1SDimitry Andric .addDef(NewReg) 122*0fca6ea1SDimitry Andric .addUse(GR.getSPIRVTypeID(NewPtrType)) 123*0fca6ea1SDimitry Andric .addUse(OpReg) 124*0fca6ea1SDimitry Andric .constrainAllUses(*STI.getInstrInfo(), *STI.getRegisterInfo(), 125*0fca6ea1SDimitry Andric *STI.getRegBankInfo()); 126*0fca6ea1SDimitry Andric if (!Res) 127*0fca6ea1SDimitry Andric report_fatal_error("insert validation bitcast: cannot constrain all uses"); 128*0fca6ea1SDimitry Andric MRI->setRegClass(NewReg, &SPIRV::IDRegClass); 129*0fca6ea1SDimitry Andric GR.assignSPIRVTypeToVReg(NewPtrType, NewReg, MIB.getMF()); 130*0fca6ea1SDimitry Andric I.getOperand(OpIdx).setReg(NewReg); 131*0fca6ea1SDimitry Andric } 132*0fca6ea1SDimitry Andric 133*0fca6ea1SDimitry Andric static SPIRVType *createNewPtrType(SPIRVGlobalRegistry &GR, MachineInstr &I, 134*0fca6ea1SDimitry Andric SPIRVType *OpType, bool ReuseType, 135*0fca6ea1SDimitry Andric bool EmitIR, SPIRVType *ResType, 136*0fca6ea1SDimitry Andric const Type *ResTy) { 137*0fca6ea1SDimitry Andric SPIRV::StorageClass::StorageClass SC = 138*0fca6ea1SDimitry Andric static_cast<SPIRV::StorageClass::StorageClass>( 139*0fca6ea1SDimitry Andric OpType->getOperand(1).getImm()); 140*0fca6ea1SDimitry Andric MachineIRBuilder MIB(I); 141*0fca6ea1SDimitry Andric SPIRVType *NewBaseType = 142*0fca6ea1SDimitry Andric ReuseType ? ResType 143*0fca6ea1SDimitry Andric : GR.getOrCreateSPIRVType( 144*0fca6ea1SDimitry Andric ResTy, MIB, SPIRV::AccessQualifier::ReadWrite, EmitIR); 145*0fca6ea1SDimitry Andric return GR.getOrCreateSPIRVPointerType(NewBaseType, MIB, SC); 146*0fca6ea1SDimitry Andric } 147*0fca6ea1SDimitry Andric 148*0fca6ea1SDimitry Andric // Insert a bitcast before the instruction to keep SPIR-V code valid 149*0fca6ea1SDimitry Andric // when there is a type mismatch between results and operand types. 150*0fca6ea1SDimitry Andric static void validatePtrTypes(const SPIRVSubtarget &STI, 151*0fca6ea1SDimitry Andric MachineRegisterInfo *MRI, SPIRVGlobalRegistry &GR, 152*0fca6ea1SDimitry Andric MachineInstr &I, unsigned OpIdx, 153*0fca6ea1SDimitry Andric SPIRVType *ResType, const Type *ResTy = nullptr) { 154*0fca6ea1SDimitry Andric // Get operand type 155*0fca6ea1SDimitry Andric MachineFunction *MF = I.getParent()->getParent(); 156*0fca6ea1SDimitry Andric Register OpReg = I.getOperand(OpIdx).getReg(); 157*0fca6ea1SDimitry Andric Register OpTypeReg = getTypeReg(MRI, OpReg); 158*0fca6ea1SDimitry Andric SPIRVType *OpType = GR.getSPIRVTypeForVReg(OpTypeReg, MF); 159*0fca6ea1SDimitry Andric if (!ResType || !OpType || OpType->getOpcode() != SPIRV::OpTypePointer) 160*0fca6ea1SDimitry Andric return; 161*0fca6ea1SDimitry Andric // Get operand's pointee type 162*0fca6ea1SDimitry Andric Register ElemTypeReg = OpType->getOperand(2).getReg(); 163*0fca6ea1SDimitry Andric SPIRVType *ElemType = GR.getSPIRVTypeForVReg(ElemTypeReg, MF); 164*0fca6ea1SDimitry Andric if (!ElemType) 165*0fca6ea1SDimitry Andric return; 166*0fca6ea1SDimitry Andric // Check if we need a bitcast to make a statement valid 167*0fca6ea1SDimitry Andric bool IsSameMF = MF == ResType->getParent()->getParent(); 168*0fca6ea1SDimitry Andric bool IsEqualTypes = IsSameMF ? ElemType == ResType 169*0fca6ea1SDimitry Andric : GR.getTypeForSPIRVType(ElemType) == ResTy; 170*0fca6ea1SDimitry Andric if (IsEqualTypes) 171*0fca6ea1SDimitry Andric return; 172*0fca6ea1SDimitry Andric // There is a type mismatch between results and operand types 173*0fca6ea1SDimitry Andric // and we insert a bitcast before the instruction to keep SPIR-V code valid 174*0fca6ea1SDimitry Andric SPIRVType *NewPtrType = 175*0fca6ea1SDimitry Andric createNewPtrType(GR, I, OpType, IsSameMF, false, ResType, ResTy); 176*0fca6ea1SDimitry Andric if (!GR.isBitcastCompatible(NewPtrType, OpType)) 177*0fca6ea1SDimitry Andric report_fatal_error( 178*0fca6ea1SDimitry Andric "insert validation bitcast: incompatible result and operand types"); 179*0fca6ea1SDimitry Andric doInsertBitcast(STI, MRI, GR, I, OpReg, OpIdx, NewPtrType); 180*0fca6ea1SDimitry Andric } 181*0fca6ea1SDimitry Andric 182*0fca6ea1SDimitry Andric // Insert a bitcast before OpGroupWaitEvents if the last argument is a pointer 183*0fca6ea1SDimitry Andric // that doesn't point to OpTypeEvent. 184*0fca6ea1SDimitry Andric static void validateGroupWaitEventsPtr(const SPIRVSubtarget &STI, 185*0fca6ea1SDimitry Andric MachineRegisterInfo *MRI, 186*0fca6ea1SDimitry Andric SPIRVGlobalRegistry &GR, 187*0fca6ea1SDimitry Andric MachineInstr &I) { 188*0fca6ea1SDimitry Andric constexpr unsigned OpIdx = 2; 189*0fca6ea1SDimitry Andric MachineFunction *MF = I.getParent()->getParent(); 190*0fca6ea1SDimitry Andric Register OpReg = I.getOperand(OpIdx).getReg(); 191*0fca6ea1SDimitry Andric Register OpTypeReg = getTypeReg(MRI, OpReg); 192*0fca6ea1SDimitry Andric SPIRVType *OpType = GR.getSPIRVTypeForVReg(OpTypeReg, MF); 193*0fca6ea1SDimitry Andric if (!OpType || OpType->getOpcode() != SPIRV::OpTypePointer) 194*0fca6ea1SDimitry Andric return; 195*0fca6ea1SDimitry Andric SPIRVType *ElemType = GR.getSPIRVTypeForVReg(OpType->getOperand(2).getReg()); 196*0fca6ea1SDimitry Andric if (!ElemType || ElemType->getOpcode() == SPIRV::OpTypeEvent) 197*0fca6ea1SDimitry Andric return; 198*0fca6ea1SDimitry Andric // Insert a bitcast before the instruction to keep SPIR-V code valid. 199*0fca6ea1SDimitry Andric LLVMContext &Context = MF->getFunction().getContext(); 200*0fca6ea1SDimitry Andric SPIRVType *NewPtrType = 201*0fca6ea1SDimitry Andric createNewPtrType(GR, I, OpType, false, true, nullptr, 202*0fca6ea1SDimitry Andric TargetExtType::get(Context, "spirv.Event")); 203*0fca6ea1SDimitry Andric doInsertBitcast(STI, MRI, GR, I, OpReg, OpIdx, NewPtrType); 204*0fca6ea1SDimitry Andric } 205*0fca6ea1SDimitry Andric 206*0fca6ea1SDimitry Andric static void validateGroupAsyncCopyPtr(const SPIRVSubtarget &STI, 207*0fca6ea1SDimitry Andric MachineRegisterInfo *MRI, 208*0fca6ea1SDimitry Andric SPIRVGlobalRegistry &GR, MachineInstr &I, 209*0fca6ea1SDimitry Andric unsigned OpIdx) { 210*0fca6ea1SDimitry Andric MachineFunction *MF = I.getParent()->getParent(); 211*0fca6ea1SDimitry Andric Register OpReg = I.getOperand(OpIdx).getReg(); 212*0fca6ea1SDimitry Andric Register OpTypeReg = getTypeReg(MRI, OpReg); 213*0fca6ea1SDimitry Andric SPIRVType *OpType = GR.getSPIRVTypeForVReg(OpTypeReg, MF); 214*0fca6ea1SDimitry Andric if (!OpType || OpType->getOpcode() != SPIRV::OpTypePointer) 215*0fca6ea1SDimitry Andric return; 216*0fca6ea1SDimitry Andric SPIRVType *ElemType = GR.getSPIRVTypeForVReg(OpType->getOperand(2).getReg()); 217*0fca6ea1SDimitry Andric if (!ElemType || ElemType->getOpcode() != SPIRV::OpTypeStruct || 218*0fca6ea1SDimitry Andric ElemType->getNumOperands() != 2) 219*0fca6ea1SDimitry Andric return; 220*0fca6ea1SDimitry Andric // It's a structure-wrapper around another type with a single member field. 221*0fca6ea1SDimitry Andric SPIRVType *MemberType = 222*0fca6ea1SDimitry Andric GR.getSPIRVTypeForVReg(ElemType->getOperand(1).getReg()); 223*0fca6ea1SDimitry Andric if (!MemberType) 224*0fca6ea1SDimitry Andric return; 225*0fca6ea1SDimitry Andric unsigned MemberTypeOp = MemberType->getOpcode(); 226*0fca6ea1SDimitry Andric if (MemberTypeOp != SPIRV::OpTypeVector && MemberTypeOp != SPIRV::OpTypeInt && 227*0fca6ea1SDimitry Andric MemberTypeOp != SPIRV::OpTypeFloat && MemberTypeOp != SPIRV::OpTypeBool) 228*0fca6ea1SDimitry Andric return; 229*0fca6ea1SDimitry Andric // It's a structure-wrapper around a valid type. Insert a bitcast before the 230*0fca6ea1SDimitry Andric // instruction to keep SPIR-V code valid. 231*0fca6ea1SDimitry Andric SPIRV::StorageClass::StorageClass SC = 232*0fca6ea1SDimitry Andric static_cast<SPIRV::StorageClass::StorageClass>( 233*0fca6ea1SDimitry Andric OpType->getOperand(1).getImm()); 234*0fca6ea1SDimitry Andric MachineIRBuilder MIB(I); 235*0fca6ea1SDimitry Andric SPIRVType *NewPtrType = GR.getOrCreateSPIRVPointerType(MemberType, MIB, SC); 236*0fca6ea1SDimitry Andric doInsertBitcast(STI, MRI, GR, I, OpReg, OpIdx, NewPtrType); 237*0fca6ea1SDimitry Andric } 238*0fca6ea1SDimitry Andric 239*0fca6ea1SDimitry Andric // Insert a bitcast before the function call instruction to keep SPIR-V code 240*0fca6ea1SDimitry Andric // valid when there is a type mismatch between actual and expected types of an 241*0fca6ea1SDimitry Andric // argument: 242*0fca6ea1SDimitry Andric // %formal = OpFunctionParameter %formal_type 243*0fca6ea1SDimitry Andric // ... 244*0fca6ea1SDimitry Andric // %res = OpFunctionCall %ty %fun %actual ... 245*0fca6ea1SDimitry Andric // implies that %actual is of %formal_type, and in case of opaque pointers. 246*0fca6ea1SDimitry Andric // We may need to insert a bitcast to ensure this. 247*0fca6ea1SDimitry Andric void validateFunCallMachineDef(const SPIRVSubtarget &STI, 248*0fca6ea1SDimitry Andric MachineRegisterInfo *DefMRI, 249*0fca6ea1SDimitry Andric MachineRegisterInfo *CallMRI, 250*0fca6ea1SDimitry Andric SPIRVGlobalRegistry &GR, MachineInstr &FunCall, 251*0fca6ea1SDimitry Andric MachineInstr *FunDef) { 252*0fca6ea1SDimitry Andric if (FunDef->getOpcode() != SPIRV::OpFunction) 253*0fca6ea1SDimitry Andric return; 254*0fca6ea1SDimitry Andric unsigned OpIdx = 3; 255*0fca6ea1SDimitry Andric for (FunDef = FunDef->getNextNode(); 256*0fca6ea1SDimitry Andric FunDef && FunDef->getOpcode() == SPIRV::OpFunctionParameter && 257*0fca6ea1SDimitry Andric OpIdx < FunCall.getNumOperands(); 258*0fca6ea1SDimitry Andric FunDef = FunDef->getNextNode(), OpIdx++) { 259*0fca6ea1SDimitry Andric SPIRVType *DefPtrType = DefMRI->getVRegDef(FunDef->getOperand(1).getReg()); 260*0fca6ea1SDimitry Andric SPIRVType *DefElemType = 261*0fca6ea1SDimitry Andric DefPtrType && DefPtrType->getOpcode() == SPIRV::OpTypePointer 262*0fca6ea1SDimitry Andric ? GR.getSPIRVTypeForVReg(DefPtrType->getOperand(2).getReg(), 263*0fca6ea1SDimitry Andric DefPtrType->getParent()->getParent()) 264*0fca6ea1SDimitry Andric : nullptr; 265*0fca6ea1SDimitry Andric if (DefElemType) { 266*0fca6ea1SDimitry Andric const Type *DefElemTy = GR.getTypeForSPIRVType(DefElemType); 267*0fca6ea1SDimitry Andric // validatePtrTypes() works in the context if the call site 268*0fca6ea1SDimitry Andric // When we process historical records about forward calls 269*0fca6ea1SDimitry Andric // we need to switch context to the (forward) call site and 270*0fca6ea1SDimitry Andric // then restore it back to the current machine function. 271*0fca6ea1SDimitry Andric MachineFunction *CurMF = 272*0fca6ea1SDimitry Andric GR.setCurrentFunc(*FunCall.getParent()->getParent()); 273*0fca6ea1SDimitry Andric validatePtrTypes(STI, CallMRI, GR, FunCall, OpIdx, DefElemType, 274*0fca6ea1SDimitry Andric DefElemTy); 275*0fca6ea1SDimitry Andric GR.setCurrentFunc(*CurMF); 276*0fca6ea1SDimitry Andric } 277*0fca6ea1SDimitry Andric } 278*0fca6ea1SDimitry Andric } 279*0fca6ea1SDimitry Andric 280*0fca6ea1SDimitry Andric // Ensure there is no mismatch between actual and expected arg types: calls 281*0fca6ea1SDimitry Andric // with a processed definition. Return Function pointer if it's a forward 282*0fca6ea1SDimitry Andric // call (ahead of definition), and nullptr otherwise. 283*0fca6ea1SDimitry Andric const Function *validateFunCall(const SPIRVSubtarget &STI, 284*0fca6ea1SDimitry Andric MachineRegisterInfo *CallMRI, 285*0fca6ea1SDimitry Andric SPIRVGlobalRegistry &GR, 286*0fca6ea1SDimitry Andric MachineInstr &FunCall) { 287*0fca6ea1SDimitry Andric const GlobalValue *GV = FunCall.getOperand(2).getGlobal(); 288*0fca6ea1SDimitry Andric const Function *F = dyn_cast<Function>(GV); 289*0fca6ea1SDimitry Andric MachineInstr *FunDef = 290*0fca6ea1SDimitry Andric const_cast<MachineInstr *>(GR.getFunctionDefinition(F)); 291*0fca6ea1SDimitry Andric if (!FunDef) 292*0fca6ea1SDimitry Andric return F; 293*0fca6ea1SDimitry Andric MachineRegisterInfo *DefMRI = &FunDef->getParent()->getParent()->getRegInfo(); 294*0fca6ea1SDimitry Andric validateFunCallMachineDef(STI, DefMRI, CallMRI, GR, FunCall, FunDef); 295*0fca6ea1SDimitry Andric return nullptr; 296*0fca6ea1SDimitry Andric } 297*0fca6ea1SDimitry Andric 298*0fca6ea1SDimitry Andric // Ensure there is no mismatch between actual and expected arg types: calls 299*0fca6ea1SDimitry Andric // ahead of a processed definition. 300*0fca6ea1SDimitry Andric void validateForwardCalls(const SPIRVSubtarget &STI, 301*0fca6ea1SDimitry Andric MachineRegisterInfo *DefMRI, SPIRVGlobalRegistry &GR, 302*0fca6ea1SDimitry Andric MachineInstr &FunDef) { 303*0fca6ea1SDimitry Andric const Function *F = GR.getFunctionByDefinition(&FunDef); 304*0fca6ea1SDimitry Andric if (SmallPtrSet<MachineInstr *, 8> *FwdCalls = GR.getForwardCalls(F)) 305*0fca6ea1SDimitry Andric for (MachineInstr *FunCall : *FwdCalls) { 306*0fca6ea1SDimitry Andric MachineRegisterInfo *CallMRI = 307*0fca6ea1SDimitry Andric &FunCall->getParent()->getParent()->getRegInfo(); 308*0fca6ea1SDimitry Andric validateFunCallMachineDef(STI, DefMRI, CallMRI, GR, *FunCall, &FunDef); 309*0fca6ea1SDimitry Andric } 310*0fca6ea1SDimitry Andric } 311*0fca6ea1SDimitry Andric 312*0fca6ea1SDimitry Andric // Validation of an access chain. 313*0fca6ea1SDimitry Andric void validateAccessChain(const SPIRVSubtarget &STI, MachineRegisterInfo *MRI, 314*0fca6ea1SDimitry Andric SPIRVGlobalRegistry &GR, MachineInstr &I) { 315*0fca6ea1SDimitry Andric SPIRVType *BaseTypeInst = GR.getSPIRVTypeForVReg(I.getOperand(0).getReg()); 316*0fca6ea1SDimitry Andric if (BaseTypeInst && BaseTypeInst->getOpcode() == SPIRV::OpTypePointer) { 317*0fca6ea1SDimitry Andric SPIRVType *BaseElemType = 318*0fca6ea1SDimitry Andric GR.getSPIRVTypeForVReg(BaseTypeInst->getOperand(2).getReg()); 319*0fca6ea1SDimitry Andric validatePtrTypes(STI, MRI, GR, I, 2, BaseElemType); 320*0fca6ea1SDimitry Andric } 321*0fca6ea1SDimitry Andric } 322*0fca6ea1SDimitry Andric 323*0fca6ea1SDimitry Andric // TODO: the logic of inserting additional bitcast's is to be moved 324*0fca6ea1SDimitry Andric // to pre-IRTranslation passes eventually 325*0fca6ea1SDimitry Andric void SPIRVTargetLowering::finalizeLowering(MachineFunction &MF) const { 326*0fca6ea1SDimitry Andric // finalizeLowering() is called twice (see GlobalISel/InstructionSelect.cpp) 327*0fca6ea1SDimitry Andric // We'd like to avoid the needless second processing pass. 328*0fca6ea1SDimitry Andric if (ProcessedMF.find(&MF) != ProcessedMF.end()) 329*0fca6ea1SDimitry Andric return; 330*0fca6ea1SDimitry Andric 331*0fca6ea1SDimitry Andric MachineRegisterInfo *MRI = &MF.getRegInfo(); 332*0fca6ea1SDimitry Andric SPIRVGlobalRegistry &GR = *STI.getSPIRVGlobalRegistry(); 333*0fca6ea1SDimitry Andric GR.setCurrentFunc(MF); 334*0fca6ea1SDimitry Andric for (MachineFunction::iterator I = MF.begin(), E = MF.end(); I != E; ++I) { 335*0fca6ea1SDimitry Andric MachineBasicBlock *MBB = &*I; 336*0fca6ea1SDimitry Andric for (MachineBasicBlock::iterator MBBI = MBB->begin(), MBBE = MBB->end(); 337*0fca6ea1SDimitry Andric MBBI != MBBE;) { 338*0fca6ea1SDimitry Andric MachineInstr &MI = *MBBI++; 339*0fca6ea1SDimitry Andric switch (MI.getOpcode()) { 340*0fca6ea1SDimitry Andric case SPIRV::OpAtomicLoad: 341*0fca6ea1SDimitry Andric case SPIRV::OpAtomicExchange: 342*0fca6ea1SDimitry Andric case SPIRV::OpAtomicCompareExchange: 343*0fca6ea1SDimitry Andric case SPIRV::OpAtomicCompareExchangeWeak: 344*0fca6ea1SDimitry Andric case SPIRV::OpAtomicIIncrement: 345*0fca6ea1SDimitry Andric case SPIRV::OpAtomicIDecrement: 346*0fca6ea1SDimitry Andric case SPIRV::OpAtomicIAdd: 347*0fca6ea1SDimitry Andric case SPIRV::OpAtomicISub: 348*0fca6ea1SDimitry Andric case SPIRV::OpAtomicSMin: 349*0fca6ea1SDimitry Andric case SPIRV::OpAtomicUMin: 350*0fca6ea1SDimitry Andric case SPIRV::OpAtomicSMax: 351*0fca6ea1SDimitry Andric case SPIRV::OpAtomicUMax: 352*0fca6ea1SDimitry Andric case SPIRV::OpAtomicAnd: 353*0fca6ea1SDimitry Andric case SPIRV::OpAtomicOr: 354*0fca6ea1SDimitry Andric case SPIRV::OpAtomicXor: 355*0fca6ea1SDimitry Andric // for the above listed instructions 356*0fca6ea1SDimitry Andric // OpAtomicXXX <ResType>, ptr %Op, ... 357*0fca6ea1SDimitry Andric // implies that %Op is a pointer to <ResType> 358*0fca6ea1SDimitry Andric case SPIRV::OpLoad: 359*0fca6ea1SDimitry Andric // OpLoad <ResType>, ptr %Op implies that %Op is a pointer to <ResType> 360*0fca6ea1SDimitry Andric validatePtrTypes(STI, MRI, GR, MI, 2, 361*0fca6ea1SDimitry Andric GR.getSPIRVTypeForVReg(MI.getOperand(0).getReg())); 362*0fca6ea1SDimitry Andric break; 363*0fca6ea1SDimitry Andric case SPIRV::OpAtomicStore: 364*0fca6ea1SDimitry Andric // OpAtomicStore ptr %Op, <Scope>, <Mem>, <Obj> 365*0fca6ea1SDimitry Andric // implies that %Op points to the <Obj>'s type 366*0fca6ea1SDimitry Andric validatePtrTypes(STI, MRI, GR, MI, 0, 367*0fca6ea1SDimitry Andric GR.getSPIRVTypeForVReg(MI.getOperand(3).getReg())); 368*0fca6ea1SDimitry Andric break; 369*0fca6ea1SDimitry Andric case SPIRV::OpStore: 370*0fca6ea1SDimitry Andric // OpStore ptr %Op, <Obj> implies that %Op points to the <Obj>'s type 371*0fca6ea1SDimitry Andric validatePtrTypes(STI, MRI, GR, MI, 0, 372*0fca6ea1SDimitry Andric GR.getSPIRVTypeForVReg(MI.getOperand(1).getReg())); 373*0fca6ea1SDimitry Andric break; 374*0fca6ea1SDimitry Andric case SPIRV::OpPtrCastToGeneric: 375*0fca6ea1SDimitry Andric case SPIRV::OpGenericCastToPtr: 376*0fca6ea1SDimitry Andric validateAccessChain(STI, MRI, GR, MI); 377*0fca6ea1SDimitry Andric break; 378*0fca6ea1SDimitry Andric case SPIRV::OpInBoundsPtrAccessChain: 379*0fca6ea1SDimitry Andric if (MI.getNumOperands() == 4) 380*0fca6ea1SDimitry Andric validateAccessChain(STI, MRI, GR, MI); 381*0fca6ea1SDimitry Andric break; 382*0fca6ea1SDimitry Andric 383*0fca6ea1SDimitry Andric case SPIRV::OpFunctionCall: 384*0fca6ea1SDimitry Andric // ensure there is no mismatch between actual and expected arg types: 385*0fca6ea1SDimitry Andric // calls with a processed definition 386*0fca6ea1SDimitry Andric if (MI.getNumOperands() > 3) 387*0fca6ea1SDimitry Andric if (const Function *F = validateFunCall(STI, MRI, GR, MI)) 388*0fca6ea1SDimitry Andric GR.addForwardCall(F, &MI); 389*0fca6ea1SDimitry Andric break; 390*0fca6ea1SDimitry Andric case SPIRV::OpFunction: 391*0fca6ea1SDimitry Andric // ensure there is no mismatch between actual and expected arg types: 392*0fca6ea1SDimitry Andric // calls ahead of a processed definition 393*0fca6ea1SDimitry Andric validateForwardCalls(STI, MRI, GR, MI); 394*0fca6ea1SDimitry Andric break; 395*0fca6ea1SDimitry Andric 396*0fca6ea1SDimitry Andric // ensure that LLVM IR bitwise instructions result in logical SPIR-V 397*0fca6ea1SDimitry Andric // instructions when applied to bool type 398*0fca6ea1SDimitry Andric case SPIRV::OpBitwiseOrS: 399*0fca6ea1SDimitry Andric case SPIRV::OpBitwiseOrV: 400*0fca6ea1SDimitry Andric if (GR.isScalarOrVectorOfType(MI.getOperand(1).getReg(), 401*0fca6ea1SDimitry Andric SPIRV::OpTypeBool)) 402*0fca6ea1SDimitry Andric MI.setDesc(STI.getInstrInfo()->get(SPIRV::OpLogicalOr)); 403*0fca6ea1SDimitry Andric break; 404*0fca6ea1SDimitry Andric case SPIRV::OpBitwiseAndS: 405*0fca6ea1SDimitry Andric case SPIRV::OpBitwiseAndV: 406*0fca6ea1SDimitry Andric if (GR.isScalarOrVectorOfType(MI.getOperand(1).getReg(), 407*0fca6ea1SDimitry Andric SPIRV::OpTypeBool)) 408*0fca6ea1SDimitry Andric MI.setDesc(STI.getInstrInfo()->get(SPIRV::OpLogicalAnd)); 409*0fca6ea1SDimitry Andric break; 410*0fca6ea1SDimitry Andric case SPIRV::OpBitwiseXorS: 411*0fca6ea1SDimitry Andric case SPIRV::OpBitwiseXorV: 412*0fca6ea1SDimitry Andric if (GR.isScalarOrVectorOfType(MI.getOperand(1).getReg(), 413*0fca6ea1SDimitry Andric SPIRV::OpTypeBool)) 414*0fca6ea1SDimitry Andric MI.setDesc(STI.getInstrInfo()->get(SPIRV::OpLogicalNotEqual)); 415*0fca6ea1SDimitry Andric break; 416*0fca6ea1SDimitry Andric case SPIRV::OpGroupAsyncCopy: 417*0fca6ea1SDimitry Andric validateGroupAsyncCopyPtr(STI, MRI, GR, MI, 3); 418*0fca6ea1SDimitry Andric validateGroupAsyncCopyPtr(STI, MRI, GR, MI, 4); 419*0fca6ea1SDimitry Andric break; 420*0fca6ea1SDimitry Andric case SPIRV::OpGroupWaitEvents: 421*0fca6ea1SDimitry Andric // OpGroupWaitEvents ..., ..., <pointer to OpTypeEvent> 422*0fca6ea1SDimitry Andric validateGroupWaitEventsPtr(STI, MRI, GR, MI); 423*0fca6ea1SDimitry Andric break; 424*0fca6ea1SDimitry Andric case SPIRV::OpConstantI: { 425*0fca6ea1SDimitry Andric SPIRVType *Type = GR.getSPIRVTypeForVReg(MI.getOperand(1).getReg()); 426*0fca6ea1SDimitry Andric if (Type->getOpcode() != SPIRV::OpTypeInt && MI.getOperand(2).isImm() && 427*0fca6ea1SDimitry Andric MI.getOperand(2).getImm() == 0) { 428*0fca6ea1SDimitry Andric // Validate the null constant of a target extension type 429*0fca6ea1SDimitry Andric MI.setDesc(STI.getInstrInfo()->get(SPIRV::OpConstantNull)); 430*0fca6ea1SDimitry Andric for (unsigned i = MI.getNumOperands() - 1; i > 1; --i) 431*0fca6ea1SDimitry Andric MI.removeOperand(i); 432*0fca6ea1SDimitry Andric } 433*0fca6ea1SDimitry Andric } break; 434*0fca6ea1SDimitry Andric } 435*0fca6ea1SDimitry Andric } 436*0fca6ea1SDimitry Andric } 437*0fca6ea1SDimitry Andric ProcessedMF.insert(&MF); 438*0fca6ea1SDimitry Andric TargetLowering::finalizeLowering(MF); 439*0fca6ea1SDimitry Andric } 440