xref: /freebsd-src/contrib/llvm-project/llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp (revision bdd1243df58e60e85101c09001d9812a789b6bc4)
181ad6265SDimitry Andric //===- SPIRVISelLowering.cpp - SPIR-V DAG Lowering Impl ---------*- C++ -*-===//
281ad6265SDimitry Andric //
381ad6265SDimitry Andric // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
481ad6265SDimitry Andric // See https://llvm.org/LICENSE.txt for license information.
581ad6265SDimitry Andric // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
681ad6265SDimitry Andric //
781ad6265SDimitry Andric //===----------------------------------------------------------------------===//
881ad6265SDimitry Andric //
981ad6265SDimitry Andric // This file implements the SPIRVTargetLowering class.
1081ad6265SDimitry Andric //
1181ad6265SDimitry Andric //===----------------------------------------------------------------------===//
1281ad6265SDimitry Andric 
1381ad6265SDimitry Andric #include "SPIRVISelLowering.h"
1481ad6265SDimitry Andric #include "SPIRV.h"
15*bdd1243dSDimitry Andric #include "llvm/IR/IntrinsicsSPIRV.h"
1681ad6265SDimitry Andric 
1781ad6265SDimitry Andric #define DEBUG_TYPE "spirv-lower"
1881ad6265SDimitry Andric 
1981ad6265SDimitry Andric using namespace llvm;
2081ad6265SDimitry Andric 
2181ad6265SDimitry Andric unsigned SPIRVTargetLowering::getNumRegistersForCallingConv(
2281ad6265SDimitry Andric     LLVMContext &Context, CallingConv::ID CC, EVT VT) const {
2381ad6265SDimitry Andric   // This code avoids CallLowering fail inside getVectorTypeBreakdown
2481ad6265SDimitry Andric   // on v3i1 arguments. Maybe we need to return 1 for all types.
2581ad6265SDimitry Andric   // TODO: remove it once this case is supported by the default implementation.
2681ad6265SDimitry Andric   if (VT.isVector() && VT.getVectorNumElements() == 3 &&
2781ad6265SDimitry Andric       (VT.getVectorElementType() == MVT::i1 ||
2881ad6265SDimitry Andric        VT.getVectorElementType() == MVT::i8))
2981ad6265SDimitry Andric     return 1;
3081ad6265SDimitry Andric   return getNumRegisters(Context, VT);
3181ad6265SDimitry Andric }
3281ad6265SDimitry Andric 
3381ad6265SDimitry Andric MVT SPIRVTargetLowering::getRegisterTypeForCallingConv(LLVMContext &Context,
3481ad6265SDimitry Andric                                                        CallingConv::ID CC,
3581ad6265SDimitry Andric                                                        EVT VT) const {
3681ad6265SDimitry Andric   // This code avoids CallLowering fail inside getVectorTypeBreakdown
3781ad6265SDimitry Andric   // on v3i1 arguments. Maybe we need to return i32 for all types.
3881ad6265SDimitry Andric   // TODO: remove it once this case is supported by the default implementation.
3981ad6265SDimitry Andric   if (VT.isVector() && VT.getVectorNumElements() == 3) {
4081ad6265SDimitry Andric     if (VT.getVectorElementType() == MVT::i1)
4181ad6265SDimitry Andric       return MVT::v4i1;
4281ad6265SDimitry Andric     else if (VT.getVectorElementType() == MVT::i8)
4381ad6265SDimitry Andric       return MVT::v4i8;
4481ad6265SDimitry Andric   }
4581ad6265SDimitry Andric   return getRegisterType(Context, VT);
4681ad6265SDimitry Andric }
47*bdd1243dSDimitry Andric 
48*bdd1243dSDimitry Andric bool SPIRVTargetLowering::getTgtMemIntrinsic(IntrinsicInfo &Info,
49*bdd1243dSDimitry Andric                                              const CallInst &I,
50*bdd1243dSDimitry Andric                                              MachineFunction &MF,
51*bdd1243dSDimitry Andric                                              unsigned Intrinsic) const {
52*bdd1243dSDimitry Andric   unsigned AlignIdx = 3;
53*bdd1243dSDimitry Andric   switch (Intrinsic) {
54*bdd1243dSDimitry Andric   case Intrinsic::spv_load:
55*bdd1243dSDimitry Andric     AlignIdx = 2;
56*bdd1243dSDimitry Andric     LLVM_FALLTHROUGH;
57*bdd1243dSDimitry Andric   case Intrinsic::spv_store: {
58*bdd1243dSDimitry Andric     if (I.getNumOperands() >= AlignIdx + 1) {
59*bdd1243dSDimitry Andric       auto *AlignOp = cast<ConstantInt>(I.getOperand(AlignIdx));
60*bdd1243dSDimitry Andric       Info.align = Align(AlignOp->getZExtValue());
61*bdd1243dSDimitry Andric     }
62*bdd1243dSDimitry Andric     Info.flags = static_cast<MachineMemOperand::Flags>(
63*bdd1243dSDimitry Andric         cast<ConstantInt>(I.getOperand(AlignIdx - 1))->getZExtValue());
64*bdd1243dSDimitry Andric     Info.memVT = MVT::i64;
65*bdd1243dSDimitry Andric     // TODO: take into account opaque pointers (don't use getElementType).
66*bdd1243dSDimitry Andric     // MVT::getVT(PtrTy->getElementType());
67*bdd1243dSDimitry Andric     return true;
68*bdd1243dSDimitry Andric     break;
69*bdd1243dSDimitry Andric   }
70*bdd1243dSDimitry Andric   default:
71*bdd1243dSDimitry Andric     break;
72*bdd1243dSDimitry Andric   }
73*bdd1243dSDimitry Andric   return false;
74*bdd1243dSDimitry Andric }
75