xref: /llvm-project/llvm/lib/Target/SPIRV/SPIRVPostLegalizer.cpp (revision 978de2d6664a74864471d62244700c216fdc6741)
//===-- SPIRVPostLegalizer.cpp - amend info after legalization -*- C++ -*-====//
//
// Amends SPIR-V type information for instructions which may appear after the
// legalizer pass.
4 //
5 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
6 // See https://llvm.org/LICENSE.txt for license information.
7 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
8 //
9 //===----------------------------------------------------------------------===//
10 //
// The pass partially applies pre-legalization logic to new instructions
// inserted as a result of legalization:
// - assigns SPIR-V types to registers for new instructions.
14 //
15 //===----------------------------------------------------------------------===//
16 
17 #include "SPIRV.h"
18 #include "SPIRVSubtarget.h"
19 #include "SPIRVUtils.h"
20 #include "llvm/ADT/PostOrderIterator.h"
21 #include "llvm/Analysis/OptimizationRemarkEmitter.h"
22 #include "llvm/CodeGen/MachinePostDominators.h"
23 #include "llvm/IR/Attributes.h"
24 #include "llvm/IR/Constants.h"
25 #include "llvm/IR/DebugInfoMetadata.h"
26 #include "llvm/IR/IntrinsicsSPIRV.h"
27 #include "llvm/Target/TargetIntrinsicInfo.h"
28 #include <stack>
29 
30 #define DEBUG_TYPE "spirv-postlegalizer"
31 
32 using namespace llvm;
33 
34 namespace {
35 class SPIRVPostLegalizer : public MachineFunctionPass {
36 public:
37   static char ID;
38   SPIRVPostLegalizer() : MachineFunctionPass(ID) {
39     initializeSPIRVPostLegalizerPass(*PassRegistry::getPassRegistry());
40   }
41   bool runOnMachineFunction(MachineFunction &MF) override;
42 };
43 } // namespace
44 
// Defined in SPIRVLegalizerInfo.cpp. Used by mayBeInserted() and
// processNewInstrs() below to recognise the simple operations for which some
// pre-legalizer rules must be re-applied to legalizer-created instructions.
extern bool isTypeFoldingSupported(unsigned Opcode);

namespace llvm {
// Defined in SPIRVPreLegalizer.cpp. Reused here so that registers introduced
// by the legalizer get the same treatment the pre-legalizer gives to the
// registers it sees. NOTE(review): exact semantics (e.g. where the builder
// inserts and how operands are rewritten) live in that file - consult it
// before changing the call sites below.
extern Register insertAssignInstr(Register Reg, Type *Ty, SPIRVType *SpirvTy,
                                  SPIRVGlobalRegistry *GR,
                                  MachineIRBuilder &MIB,
                                  MachineRegisterInfo &MRI);
extern void processInstr(MachineInstr &MI, MachineIRBuilder &MIB,
                         MachineRegisterInfo &MRI, SPIRVGlobalRegistry *GR);
} // namespace llvm
57 
58 static bool mayBeInserted(unsigned Opcode) {
59   switch (Opcode) {
60   case TargetOpcode::G_SMAX:
61   case TargetOpcode::G_UMAX:
62   case TargetOpcode::G_SMIN:
63   case TargetOpcode::G_UMIN:
64   case TargetOpcode::G_FMINNUM:
65   case TargetOpcode::G_FMINIMUM:
66   case TargetOpcode::G_FMAXNUM:
67   case TargetOpcode::G_FMAXIMUM:
68     return true;
69   default:
70     return isTypeFoldingSupported(Opcode);
71   }
72 }
73 
74 static void processNewInstrs(MachineFunction &MF, SPIRVGlobalRegistry *GR,
75                              MachineIRBuilder MIB) {
76   MachineRegisterInfo &MRI = MF.getRegInfo();
77 
78   for (MachineBasicBlock &MBB : MF) {
79     for (MachineInstr &I : MBB) {
80       const unsigned Opcode = I.getOpcode();
81       if (Opcode == TargetOpcode::G_UNMERGE_VALUES) {
82         unsigned ArgI = I.getNumOperands() - 1;
83         Register SrcReg = I.getOperand(ArgI).isReg()
84                               ? I.getOperand(ArgI).getReg()
85                               : Register(0);
86         SPIRVType *DefType =
87             SrcReg.isValid() ? GR->getSPIRVTypeForVReg(SrcReg) : nullptr;
88         if (!DefType || DefType->getOpcode() != SPIRV::OpTypeVector)
89           report_fatal_error(
90               "cannot select G_UNMERGE_VALUES with a non-vector argument");
91         SPIRVType *ScalarType =
92             GR->getSPIRVTypeForVReg(DefType->getOperand(1).getReg());
93         for (unsigned i = 0; i < I.getNumDefs(); ++i) {
94           Register ResVReg = I.getOperand(i).getReg();
95           SPIRVType *ResType = GR->getSPIRVTypeForVReg(ResVReg);
96           if (!ResType) {
97             // There was no "assign type" actions, let's fix this now
98             ResType = ScalarType;
99             setRegClassType(ResVReg, ResType, GR, &MRI, *GR->CurMF, true);
100           }
101         }
102       } else if (mayBeInserted(Opcode) && I.getNumDefs() == 1 &&
103                  I.getNumOperands() > 1 && I.getOperand(1).isReg()) {
104         // Legalizer may have added a new instructions and introduced new
105         // registers, we must decorate them as if they were introduced in a
106         // non-automatic way
107         Register ResVReg = I.getOperand(0).getReg();
108         // Check if the register defined by the instruction is newly generated
109         // or already processed
110         if (MRI.getRegClassOrNull(ResVReg))
111           continue;
112         assert(GR->getSPIRVTypeForVReg(ResVReg) == nullptr);
113         // Check if we have type defined for operands of the new instruction
114         SPIRVType *ResVType = GR->getSPIRVTypeForVReg(I.getOperand(1).getReg());
115         if (!ResVType)
116           continue;
117         // Set type & class
118         setRegClassType(ResVReg, ResVType, GR, &MRI, *GR->CurMF, true);
119         // If this is a simple operation that is to be reduced by TableGen
120         // definition we must apply some of pre-legalizer rules here
121         if (isTypeFoldingSupported(Opcode)) {
122           insertAssignInstr(ResVReg, nullptr, ResVType, GR, MIB, MRI);
123           processInstr(I, MIB, MRI, GR);
124         }
125       }
126     }
127   }
128 }
129 
130 // Do a preorder traversal of the CFG starting from the BB |Start|.
131 // point. Calls |op| on each basic block encountered during the traversal.
132 void visit(MachineFunction &MF, MachineBasicBlock &Start,
133            std::function<void(MachineBasicBlock *)> op) {
134   std::stack<MachineBasicBlock *> ToVisit;
135   SmallPtrSet<MachineBasicBlock *, 8> Seen;
136 
137   ToVisit.push(&Start);
138   Seen.insert(ToVisit.top());
139   while (ToVisit.size() != 0) {
140     MachineBasicBlock *MBB = ToVisit.top();
141     ToVisit.pop();
142 
143     op(MBB);
144 
145     for (auto Succ : MBB->successors()) {
146       if (Seen.contains(Succ))
147         continue;
148       ToVisit.push(Succ);
149       Seen.insert(Succ);
150     }
151   }
152 }
153 
154 // Do a preorder traversal of the CFG starting from the given function's entry
155 // point. Calls |op| on each basic block encountered during the traversal.
156 void visit(MachineFunction &MF, std::function<void(MachineBasicBlock *)> op) {
157   visit(MF, *MF.begin(), op);
158 }
159 
160 bool SPIRVPostLegalizer::runOnMachineFunction(MachineFunction &MF) {
161   // Initialize the type registry.
162   const SPIRVSubtarget &ST = MF.getSubtarget<SPIRVSubtarget>();
163   SPIRVGlobalRegistry *GR = ST.getSPIRVGlobalRegistry();
164   GR->setCurrentFunc(MF);
165   MachineIRBuilder MIB(MF);
166 
167   processNewInstrs(MF, GR, MIB);
168 
169   return true;
170 }
171 
172 INITIALIZE_PASS(SPIRVPostLegalizer, DEBUG_TYPE, "SPIRV post legalizer", false,
173                 false)
174 
175 char SPIRVPostLegalizer::ID = 0;
176 
177 FunctionPass *llvm::createSPIRVPostLegalizerPass() {
178   return new SPIRVPostLegalizer();
179 }
180