xref: /llvm-project/llvm/lib/CodeGen/FixupStatepointCallerSaved.cpp (revision edbb27ccb63402b591a459f4087434ea778c23a7)
1 //===-- FixupStatepointCallerSaved.cpp - Fixup caller saved registers  ----===//
2 //
3 //                     The LLVM Compiler Infrastructure
4 //
5 // This file is distributed under the University of Illinois Open Source
6 // License. See LICENSE.TXT for details.
7 //
8 //===----------------------------------------------------------------------===//
9 ///
10 /// \file
/// The deopt parameters of a statepoint instruction hold values that are
/// meaningful to the runtime and must be readable at the moment the call
/// returns. In other words, these values are "late read" by the runtime. If
/// this notion could be expressed to the register allocator, it would produce
/// the right form for us.
/// This fixup pass exists precisely because such a late read cannot be
/// described to the register allocator, so the allocator may place a value
/// in a register that is clobbered by the call.
/// This pass forces such registers to be spilled and replaces the
/// corresponding statepoint operands with the added spill slots.
21 ///
22 //===----------------------------------------------------------------------===//
23 
24 #include "llvm/ADT/SmallSet.h"
25 #include "llvm/ADT/Statistic.h"
26 #include "llvm/CodeGen/MachineFrameInfo.h"
27 #include "llvm/CodeGen/MachineFunctionPass.h"
28 #include "llvm/CodeGen/MachineRegisterInfo.h"
29 #include "llvm/CodeGen/Passes.h"
30 #include "llvm/CodeGen/StackMaps.h"
31 #include "llvm/CodeGen/TargetFrameLowering.h"
32 #include "llvm/CodeGen/TargetInstrInfo.h"
33 #include "llvm/IR/Statepoint.h"
34 #include "llvm/InitializePasses.h"
35 #include "llvm/Support/Debug.h"
36 
37 using namespace llvm;
38 
// Tag shared by the STATISTIC counters below; it is also the argument string
// handed to the INITIALIZE_PASS macros further down in this file.
#define DEBUG_TYPE "fixup-statepoint-caller-saved"
// Bumped once for every register stored to a stack slot ahead of a statepoint.
STATISTIC(NumSpilledRegisters, "Number of spilled register");
// Bumped each time a brand new spill stack object is created.
STATISTIC(NumSpillSlotsAllocated, "Number of spill slots allocated");
// Bumped each time a cached spill slot is grown to fit a larger register.
STATISTIC(NumSpillSlotsExtended, "Number of spill slots extended");
43 
44 static cl::opt<bool> FixupSCSExtendSlotSize(
45     "fixup-scs-extend-slot-size", cl::Hidden, cl::init(false),
46     cl::desc("Allow spill in spill slot of greater size than register size"),
47     cl::Hidden);
48 
49 namespace {
50 
51 class FixupStatepointCallerSaved : public MachineFunctionPass {
52 public:
53   static char ID;
54 
55   FixupStatepointCallerSaved() : MachineFunctionPass(ID) {
56     initializeFixupStatepointCallerSavedPass(*PassRegistry::getPassRegistry());
57   }
58 
59   void getAnalysisUsage(AnalysisUsage &AU) const override {
60     MachineFunctionPass::getAnalysisUsage(AU);
61   }
62 
63   StringRef getPassName() const override {
64     return "Fixup Statepoint Caller Saved";
65   }
66 
67   bool runOnMachineFunction(MachineFunction &MF) override;
68 };
69 } // End anonymous namespace.
70 
// Storage for the pass identifier that the pass constructor hands to
// MachineFunctionPass.
char FixupStatepointCallerSaved::ID = 0;
// Makes the pass identifier reachable as llvm::FixupStatepointCallerSavedID.
char &llvm::FixupStatepointCallerSavedID = FixupStatepointCallerSaved::ID;

// Pass registration. The argument string is DEBUG_TYPE and the display name
// matches getPassName(). No dependencies are listed between BEGIN and END.
// NOTE(review): the two trailing `false` flags are presumably "CFG only" and
// "is analysis" as in the usual INITIALIZE_PASS convention - confirm against
// PassSupport.h.
INITIALIZE_PASS_BEGIN(FixupStatepointCallerSaved, DEBUG_TYPE,
                      "Fixup Statepoint Caller Saved", false, false)
INITIALIZE_PASS_END(FixupStatepointCallerSaved, DEBUG_TYPE,
                    "Fixup Statepoint Caller Saved", false, false)
78 
79 // Utility function to get size of the register.
80 static unsigned getRegisterSize(const TargetRegisterInfo &TRI, Register Reg) {
81   const TargetRegisterClass *RC = TRI.getMinimalPhysRegClass(Reg);
82   return TRI.getSpillSize(*RC);
83 }
84 
85 // Cache used frame indexes during statepoint re-write to re-use them in
86 // processing next statepoint instruction.
87 // Two strategies. One is to preserve the size of spill slot while another one
88 // extends the size of spill slots to reduce the number of them, causing
89 // the less total frame size. But unspill will have "implicit" any extend.
90 class FrameIndexesCache {
91 private:
92   struct FrameIndexesPerSize {
93     // List of used frame indexes during processing previous statepoints.
94     SmallVector<int, 8> Slots;
95     // Current index of un-used yet frame index.
96     unsigned Index = 0;
97   };
98   MachineFrameInfo &MFI;
99   const TargetRegisterInfo &TRI;
100   // Map size to list of frame indexes of this size. If the mode is
101   // FixupSCSExtendSlotSize then the key 0 is used to keep all frame indexes.
102   // If the size of required spill slot is greater than in a cache then the
103   // size will be increased.
104   DenseMap<unsigned, FrameIndexesPerSize> Cache;
105 
106 public:
107   FrameIndexesCache(MachineFrameInfo &MFI, const TargetRegisterInfo &TRI)
108       : MFI(MFI), TRI(TRI) {}
109   // Reset the current state of used frame indexes. After invocation of
110   // this function all frame indexes are available for allocation.
111   void reset() {
112     for (auto &It : Cache)
113       It.second.Index = 0;
114   }
115   // Get frame index to spill the register.
116   int getFrameIndex(Register Reg) {
117     unsigned Size = getRegisterSize(TRI, Reg);
118     // In FixupSCSExtendSlotSize mode the bucket with 0 index is used
119     // for all sizes.
120     unsigned Bucket = FixupSCSExtendSlotSize ? 0 : Size;
121     FrameIndexesPerSize &Line = Cache[Bucket];
122     if (Line.Index < Line.Slots.size()) {
123       int FI = Line.Slots[Line.Index++];
124       // If all sizes are kept together we probably need to extend the
125       // spill slot size.
126       if (MFI.getObjectSize(FI) < Size) {
127         MFI.setObjectSize(FI, Size);
128         MFI.setObjectAlignment(FI, Align(Size));
129         NumSpillSlotsExtended++;
130       }
131       return FI;
132     }
133     int FI = MFI.CreateSpillStackObject(Size, Size);
134     NumSpillSlotsAllocated++;
135     Line.Slots.push_back(FI);
136     ++Line.Index;
137     return FI;
138   }
139   // Sort all registers to spill in descendent order. In the
140   // FixupSCSExtendSlotSize mode it will minimize the total frame size.
141   // In non FixupSCSExtendSlotSize mode we can skip this step.
142   void sortRegisters(SmallVectorImpl<Register> &Regs) {
143     if (!FixupSCSExtendSlotSize)
144       return;
145     llvm::sort(Regs.begin(), Regs.end(), [&](Register &A, Register &B) {
146       return getRegisterSize(TRI, A) > getRegisterSize(TRI, B);
147     });
148   }
149 };
150 
151 // Describes the state of the current processing statepoint instruction.
152 class StatepointState {
153 private:
154   // statepoint instruction.
155   MachineInstr &MI;
156   MachineFunction &MF;
157   const TargetRegisterInfo &TRI;
158   const TargetInstrInfo &TII;
159   MachineFrameInfo &MFI;
160   // Mask with callee saved registers.
161   const uint32_t *Mask;
162   // Cache of frame indexes used on previous instruction processing.
163   FrameIndexesCache &CacheFI;
164   // Operands with physical registers requiring spilling.
165   SmallVector<unsigned, 8> OpsToSpill;
166   // Set of register to spill.
167   SmallVector<Register, 8> RegsToSpill;
168   // Map Register to Frame Slot index.
169   DenseMap<Register, int> RegToSlotIdx;
170 
171 public:
172   StatepointState(MachineInstr &MI, const uint32_t *Mask,
173                   FrameIndexesCache &CacheFI)
174       : MI(MI), MF(*MI.getMF()), TRI(*MF.getSubtarget().getRegisterInfo()),
175         TII(*MF.getSubtarget().getInstrInfo()), MFI(MF.getFrameInfo()),
176         Mask(Mask), CacheFI(CacheFI) {}
177   // Return true if register is callee saved.
178   bool isCalleeSaved(Register Reg) { return (Mask[Reg / 32] >> Reg % 32) & 1; }
179   // Iterates over statepoint meta args to find caller saver registers.
180   // Also cache the size of found registers.
181   // Returns true if caller save registers found.
182   bool findRegistersToSpill() {
183     SmallSet<Register, 8> VisitedRegs;
184     for (unsigned Idx = StatepointOpers(&MI).getVarIdx(),
185                   EndIdx = MI.getNumOperands();
186          Idx < EndIdx; ++Idx) {
187       MachineOperand &MO = MI.getOperand(Idx);
188       if (!MO.isReg() || MO.isImplicit())
189         continue;
190       Register Reg = MO.getReg();
191       assert(Reg.isPhysical() && "Only physical regs are expected");
192       if (isCalleeSaved(Reg))
193         continue;
194       if (VisitedRegs.insert(Reg).second)
195         RegsToSpill.push_back(Reg);
196       OpsToSpill.push_back(Idx);
197     }
198     CacheFI.sortRegisters(RegsToSpill);
199     return !RegsToSpill.empty();
200   }
201   // Spill all caller saved registers right before statepoint instruction.
202   // Remember frame index where register is spilled.
203   void spillRegisters() {
204     for (Register Reg : RegsToSpill) {
205       int FI = CacheFI.getFrameIndex(Reg);
206       const TargetRegisterClass *RC = TRI.getMinimalPhysRegClass(Reg);
207       TII.storeRegToStackSlot(*MI.getParent(), MI, Reg, true /*is_Kill*/, FI,
208                               RC, &TRI);
209       NumSpilledRegisters++;
210       RegToSlotIdx[Reg] = FI;
211     }
212   }
213   // Re-write statepoint machine instruction to replace caller saved operands
214   // with indirect memory location (frame index).
215   void rewriteStatepoint() {
216     MachineInstr *NewMI =
217         MF.CreateMachineInstr(TII.get(MI.getOpcode()), MI.getDebugLoc(), true);
218     MachineInstrBuilder MIB(MF, NewMI);
219 
220     // Add End marker.
221     OpsToSpill.push_back(MI.getNumOperands());
222     unsigned CurOpIdx = 0;
223 
224     for (unsigned I = 0; I < MI.getNumOperands(); ++I) {
225       MachineOperand &MO = MI.getOperand(I);
226       if (I == OpsToSpill[CurOpIdx]) {
227         int FI = RegToSlotIdx[MO.getReg()];
228         MIB.addImm(StackMaps::IndirectMemRefOp);
229         MIB.addImm(getRegisterSize(TRI, MO.getReg()));
230         assert(MO.isReg() && "Should be register");
231         assert(MO.getReg().isPhysical() && "Should be physical register");
232         MIB.addFrameIndex(FI);
233         MIB.addImm(0);
234         ++CurOpIdx;
235       } else
236         MIB.add(MO);
237     }
238     assert(CurOpIdx == (OpsToSpill.size() - 1) && "Not all operands processed");
239     // Add mem operands.
240     NewMI->setMemRefs(MF, MI.memoperands());
241     for (auto It : RegToSlotIdx) {
242       int FrameIndex = It.second;
243       auto PtrInfo = MachinePointerInfo::getFixedStack(MF, FrameIndex);
244       auto *MMO = MF.getMachineMemOperand(PtrInfo, MachineMemOperand::MOLoad,
245                                           getRegisterSize(TRI, It.first),
246                                           MFI.getObjectAlign(FrameIndex));
247       NewMI->addMemOperand(MF, MMO);
248     }
249     // Insert new statepoint and erase old one.
250     MI.getParent()->insert(MI, NewMI);
251     MI.eraseFromParent();
252   }
253 };
254 
255 class StatepointProcessor {
256 private:
257   MachineFunction &MF;
258   const TargetRegisterInfo &TRI;
259   FrameIndexesCache CacheFI;
260 
261 public:
262   StatepointProcessor(MachineFunction &MF)
263       : MF(MF), TRI(*MF.getSubtarget().getRegisterInfo()),
264         CacheFI(MF.getFrameInfo(), TRI) {}
265 
266   bool process(MachineInstr &MI) {
267     StatepointOpers SO(&MI);
268     uint64_t Flags = SO.getFlags();
269     // Do nothing for LiveIn, it supports all registers.
270     if (Flags & (uint64_t)StatepointFlags::DeoptLiveIn)
271       return false;
272     CallingConv::ID CC = SO.getCallingConv();
273     const uint32_t *Mask = TRI.getCallPreservedMask(MF, CC);
274     CacheFI.reset();
275     StatepointState SS(MI, Mask, CacheFI);
276 
277     if (!SS.findRegistersToSpill())
278       return false;
279 
280     SS.spillRegisters();
281     SS.rewriteStatepoint();
282     return true;
283   }
284 };
285 
286 bool FixupStatepointCallerSaved::runOnMachineFunction(MachineFunction &MF) {
287   if (skipFunction(MF.getFunction()))
288     return false;
289 
290   const Function &F = MF.getFunction();
291   if (!F.hasGC())
292     return false;
293 
294   SmallVector<MachineInstr *, 16> Statepoints;
295   for (MachineBasicBlock &BB : MF)
296     for (MachineInstr &I : BB)
297       if (I.getOpcode() == TargetOpcode::STATEPOINT)
298         Statepoints.push_back(&I);
299 
300   if (Statepoints.empty())
301     return false;
302 
303   bool Changed = false;
304   StatepointProcessor SPP(MF);
305   for (MachineInstr *I : Statepoints)
306     Changed |= SPP.process(*I);
307   return Changed;
308 }
309