xref: /llvm-project/llvm/lib/CodeGen/FixupStatepointCallerSaved.cpp (revision 1c80a6ce5f2217c01fb40bd43bc5bf094c32278a)
1 //===-- FixupStatepointCallerSaved.cpp - Fixup caller saved registers  ----===//
2 //
3 //                     The LLVM Compiler Infrastructure
4 //
5 // This file is distributed under the University of Illinois Open Source
6 // License. See LICENSE.TXT for details.
7 //
8 //===----------------------------------------------------------------------===//
9 ///
10 /// \file
/// The deopt parameters of a statepoint instruction contain values that are
/// meaningful to the runtime and must be readable at the moment the call
/// returns. In other words, these values are "late read" by the runtime. If
/// we could express this notion to the register allocator, it would produce
/// the right form for us. This pass exists precisely because such a late read
/// cannot be described to the register allocator: the allocator may place a
/// value in a register that is clobbered by the call. This pass forces the
/// spill of such registers and replaces the corresponding statepoint operands
/// with references to the added spill slots.
21 ///
22 //===----------------------------------------------------------------------===//
23 
24 #include "llvm/ADT/SmallSet.h"
25 #include "llvm/ADT/Statistic.h"
26 #include "llvm/CodeGen/MachineFrameInfo.h"
27 #include "llvm/CodeGen/MachineFunctionPass.h"
28 #include "llvm/CodeGen/MachineRegisterInfo.h"
29 #include "llvm/CodeGen/Passes.h"
30 #include "llvm/CodeGen/StackMaps.h"
31 #include "llvm/CodeGen/TargetFrameLowering.h"
32 #include "llvm/CodeGen/TargetInstrInfo.h"
33 #include "llvm/IR/Statepoint.h"
34 #include "llvm/InitializePasses.h"
35 #include "llvm/Support/Debug.h"
36 
37 using namespace llvm;
38 
39 #define DEBUG_TYPE "fixup-statepoint-caller-saved"
40 STATISTIC(NumSpilledRegisters, "Number of spilled register");
41 STATISTIC(NumSpillSlotsAllocated, "Number of spill slots allocated");
42 STATISTIC(NumSpillSlotsExtended, "Number of spill slots extended");
43 
44 static cl::opt<bool> FixupSCSExtendSlotSize(
45     "fixup-scs-extend-slot-size", cl::Hidden, cl::init(false),
46     cl::desc("Allow spill in spill slot of greater size than register size"),
47     cl::Hidden);
48 
49 static cl::opt<bool> PassGCPtrInCSR(
50     "fixup-allow-gcptr-in-csr", cl::Hidden, cl::init(false),
51     cl::desc("Allow passing GC Pointer arguments in callee saved registers"));
52 
53 static cl::opt<bool> EnableCopyProp(
54     "fixup-scs-enable-copy-propagation", cl::Hidden, cl::init(true),
55     cl::desc("Enable simple copy propagation during register reloading"));
56 
57 // This is purely debugging option.
58 // It may be handy for investigating statepoint spilling issues.
59 static cl::opt<unsigned> MaxStatepointsWithRegs(
60     "fixup-max-csr-statepoints", cl::Hidden,
61     cl::desc("Max number of statepoints allowed to pass GC Ptrs in registers"));
62 
63 namespace {
64 
65 class FixupStatepointCallerSaved : public MachineFunctionPass {
66 public:
67   static char ID;
68 
69   FixupStatepointCallerSaved() : MachineFunctionPass(ID) {
70     initializeFixupStatepointCallerSavedPass(*PassRegistry::getPassRegistry());
71   }
72 
73   void getAnalysisUsage(AnalysisUsage &AU) const override {
74     AU.setPreservesCFG();
75     MachineFunctionPass::getAnalysisUsage(AU);
76   }
77 
78   StringRef getPassName() const override {
79     return "Fixup Statepoint Caller Saved";
80   }
81 
82   bool runOnMachineFunction(MachineFunction &MF) override;
83 };
84 
85 } // End anonymous namespace.
86 
87 char FixupStatepointCallerSaved::ID = 0;
88 char &llvm::FixupStatepointCallerSavedID = FixupStatepointCallerSaved::ID;
89 
90 INITIALIZE_PASS_BEGIN(FixupStatepointCallerSaved, DEBUG_TYPE,
91                       "Fixup Statepoint Caller Saved", false, false)
92 INITIALIZE_PASS_END(FixupStatepointCallerSaved, DEBUG_TYPE,
93                     "Fixup Statepoint Caller Saved", false, false)
94 
95 // Utility function to get size of the register.
96 static unsigned getRegisterSize(const TargetRegisterInfo &TRI, Register Reg) {
97   const TargetRegisterClass *RC = TRI.getMinimalPhysRegClass(Reg);
98   return TRI.getSpillSize(*RC);
99 }
100 
101 // Advance iterator to the next stack map entry
102 static MachineInstr::const_mop_iterator
103 advanceToNextStackMapElt(MachineInstr::const_mop_iterator MOI) {
104   if (MOI->isImm()) {
105     switch (MOI->getImm()) {
106     default:
107       llvm_unreachable("Unrecognized operand type.");
108     case StackMaps::DirectMemRefOp:
109       MOI += 2; // <Reg>, <Imm>
110       break;
111     case StackMaps::IndirectMemRefOp:
112       MOI += 3; // <Size>, <Reg>, <Imm>
113       break;
114     case StackMaps::ConstantOp:
115       MOI += 1;
116       break;
117     }
118   }
119   return ++MOI;
120 }
121 
122 // Return statepoint GC args as a set
123 static SmallSet<Register, 8> collectGCRegs(MachineInstr &MI) {
124   StatepointOpers SO(&MI);
125   unsigned NumDeoptIdx = SO.getNumDeoptArgsIdx();
126   unsigned NumDeoptArgs = MI.getOperand(NumDeoptIdx).getImm();
127   MachineInstr::const_mop_iterator MOI(MI.operands_begin() + NumDeoptIdx + 1),
128       MOE(MI.operands_end());
129 
130   // Skip deopt args
131   while (NumDeoptArgs--)
132     MOI = advanceToNextStackMapElt(MOI);
133 
134   SmallSet<Register, 8> Result;
135   while (MOI != MOE) {
136     if (MOI->isReg() && !MOI->isImplicit())
137       Result.insert(MOI->getReg());
138     MOI = advanceToNextStackMapElt(MOI);
139   }
140   return Result;
141 }
142 
143 // Try to eliminate redundant copy to register which we're going to
144 // spill, i.e. try to change:
145 //    X = COPY Y
146 //    SPILL X
147 //  to
148 //    SPILL Y
149 //  If there are no uses of X between copy and STATEPOINT, that COPY
150 //  may be eliminated.
151 //  Reg - register we're about to spill
152 //  RI - On entry points to statepoint.
153 //       On successful copy propagation set to new spill point.
154 //  IsKill - set to true if COPY is Kill (there are no uses of Y)
155 //  Returns either found source copy register or original one.
156 static Register performCopyPropagation(Register Reg,
157                                        MachineBasicBlock::iterator &RI,
158                                        bool &IsKill, const TargetInstrInfo &TII,
159                                        const TargetRegisterInfo &TRI) {
160   // First check if statepoint itself uses Reg in non-meta operands.
161   int Idx = RI->findRegisterUseOperandIdx(Reg, false, &TRI);
162   if (Idx >= 0 && (unsigned)Idx < StatepointOpers(&*RI).getNumDeoptArgsIdx()) {
163     IsKill = false;
164     return Reg;
165   }
166 
167   if (!EnableCopyProp)
168     return Reg;
169 
170   MachineBasicBlock *MBB = RI->getParent();
171   MachineBasicBlock::reverse_iterator E = MBB->rend();
172   MachineInstr *Def = nullptr, *Use = nullptr;
173   for (auto It = ++(RI.getReverse()); It != E; ++It) {
174     if (It->readsRegister(Reg, &TRI) && !Use)
175       Use = &*It;
176     if (It->modifiesRegister(Reg, &TRI)) {
177       Def = &*It;
178       break;
179     }
180   }
181 
182   if (!Def)
183     return Reg;
184 
185   auto DestSrc = TII.isCopyInstr(*Def);
186   if (!DestSrc || DestSrc->Destination->getReg() != Reg)
187     return Reg;
188 
189   Register SrcReg = DestSrc->Source->getReg();
190 
191   if (getRegisterSize(TRI, Reg) != getRegisterSize(TRI, SrcReg))
192     return Reg;
193 
194   LLVM_DEBUG(dbgs() << "spillRegisters: perform copy propagation "
195                     << printReg(Reg, &TRI) << " -> " << printReg(SrcReg, &TRI)
196                     << "\n");
197 
198   // Insert spill immediately after Def
199   RI = ++MachineBasicBlock::iterator(Def);
200   IsKill = DestSrc->Source->isKill();
201 
202   // There are no uses of original register between COPY and STATEPOINT.
203   // There can't be any after STATEPOINT, so we can eliminate Def.
204   if (!Use) {
205     LLVM_DEBUG(dbgs() << "spillRegisters: removing dead copy " << *Def);
206     Def->eraseFromParent();
207   }
208   return SrcReg;
209 }
210 
211 namespace {
212 // Pair {Register, FrameIndex}
213 using RegSlotPair = std::pair<Register, int>;
214 
215 // Keeps track of what reloads were inserted in MBB.
216 class RegReloadCache {
217   using ReloadSet = SmallSet<RegSlotPair, 8>;
218   DenseMap<const MachineBasicBlock *, ReloadSet> Reloads;
219 
220 public:
221   RegReloadCache() = default;
222 
223   // Record reload of Reg from FI in block MBB
224   void recordReload(Register Reg, int FI, const MachineBasicBlock *MBB) {
225     RegSlotPair RSP(Reg, FI);
226     auto Res = Reloads[MBB].insert(RSP);
227     assert(Res.second && "reload already exists");
228   }
229 
230   // Does basic block MBB contains reload of Reg from FI?
231   bool hasReload(Register Reg, int FI, const MachineBasicBlock *MBB) {
232     RegSlotPair RSP(Reg, FI);
233     return Reloads.count(MBB) && Reloads[MBB].count(RSP);
234   }
235 };
236 
237 // Cache used frame indexes during statepoint re-write to re-use them in
238 // processing next statepoint instruction.
239 // Two strategies. One is to preserve the size of spill slot while another one
240 // extends the size of spill slots to reduce the number of them, causing
241 // the less total frame size. But unspill will have "implicit" any extend.
242 class FrameIndexesCache {
243 private:
244   struct FrameIndexesPerSize {
245     // List of used frame indexes during processing previous statepoints.
246     SmallVector<int, 8> Slots;
247     // Current index of un-used yet frame index.
248     unsigned Index = 0;
249   };
250   MachineFrameInfo &MFI;
251   const TargetRegisterInfo &TRI;
252   // Map size to list of frame indexes of this size. If the mode is
253   // FixupSCSExtendSlotSize then the key 0 is used to keep all frame indexes.
254   // If the size of required spill slot is greater than in a cache then the
255   // size will be increased.
256   DenseMap<unsigned, FrameIndexesPerSize> Cache;
257 
258   // Keeps track of slots reserved for the shared landing pad processing.
259   // Initialized from GlobalIndices for the current EHPad.
260   SmallSet<int, 8> ReservedSlots;
261 
262   // Landing pad can be destination of several statepoints. Every register
263   // defined by such statepoints must be spilled to the same stack slot.
264   // This map keeps that information.
265   DenseMap<const MachineBasicBlock *, SmallVector<RegSlotPair, 8>>
266       GlobalIndices;
267 
268   FrameIndexesPerSize &getCacheBucket(unsigned Size) {
269     // In FixupSCSExtendSlotSize mode the bucket with 0 index is used
270     // for all sizes.
271     return Cache[FixupSCSExtendSlotSize ? 0 : Size];
272   }
273 
274 public:
275   FrameIndexesCache(MachineFrameInfo &MFI, const TargetRegisterInfo &TRI)
276       : MFI(MFI), TRI(TRI) {}
277   // Reset the current state of used frame indexes. After invocation of
278   // this function all frame indexes are available for allocation with
279   // the exception of slots reserved for landing pad processing (if any).
280   void reset(const MachineBasicBlock *EHPad) {
281     for (auto &It : Cache)
282       It.second.Index = 0;
283 
284     ReservedSlots.clear();
285     if (EHPad && GlobalIndices.count(EHPad))
286       for (auto &RSP : GlobalIndices[EHPad])
287         ReservedSlots.insert(RSP.second);
288   }
289 
290   // Get frame index to spill the register.
291   int getFrameIndex(Register Reg, MachineBasicBlock *EHPad) {
292     // Check if slot for Reg is already reserved at EHPad.
293     auto It = GlobalIndices.find(EHPad);
294     if (It != GlobalIndices.end()) {
295       auto &Vec = It->second;
296       auto Idx = llvm::find_if(
297           Vec, [Reg](RegSlotPair &RSP) { return Reg == RSP.first; });
298       if (Idx != Vec.end()) {
299         int FI = Idx->second;
300         LLVM_DEBUG(dbgs() << "Found global FI " << FI << " for register "
301                           << printReg(Reg, &TRI) << " at "
302                           << printMBBReference(*EHPad) << "\n");
303         assert(ReservedSlots.count(FI) && "using unreserved slot");
304         return FI;
305       }
306     }
307 
308     unsigned Size = getRegisterSize(TRI, Reg);
309     FrameIndexesPerSize &Line = getCacheBucket(Size);
310     while (Line.Index < Line.Slots.size()) {
311       int FI = Line.Slots[Line.Index++];
312       if (ReservedSlots.count(FI))
313         continue;
314       // If all sizes are kept together we probably need to extend the
315       // spill slot size.
316       if (MFI.getObjectSize(FI) < Size) {
317         MFI.setObjectSize(FI, Size);
318         MFI.setObjectAlignment(FI, Align(Size));
319         NumSpillSlotsExtended++;
320       }
321       return FI;
322     }
323     int FI = MFI.CreateSpillStackObject(Size, Align(Size));
324     NumSpillSlotsAllocated++;
325     Line.Slots.push_back(FI);
326     ++Line.Index;
327 
328     // Remember assignment {Reg, FI} for EHPad
329     if (EHPad) {
330       GlobalIndices[EHPad].push_back(std::make_pair(Reg, FI));
331       LLVM_DEBUG(dbgs() << "Reserved FI " << FI << " for spilling reg "
332                         << printReg(Reg, &TRI) << " at landing pad "
333                         << printMBBReference(*EHPad) << "\n");
334     }
335 
336     return FI;
337   }
338 
339   // Sort all registers to spill in descendent order. In the
340   // FixupSCSExtendSlotSize mode it will minimize the total frame size.
341   // In non FixupSCSExtendSlotSize mode we can skip this step.
342   void sortRegisters(SmallVectorImpl<Register> &Regs) {
343     if (!FixupSCSExtendSlotSize)
344       return;
345     llvm::sort(Regs.begin(), Regs.end(), [&](Register &A, Register &B) {
346       return getRegisterSize(TRI, A) > getRegisterSize(TRI, B);
347     });
348   }
349 };
350 
351 // Describes the state of the current processing statepoint instruction.
352 class StatepointState {
353 private:
354   // statepoint instruction.
355   MachineInstr &MI;
356   MachineFunction &MF;
357   // If non-null then statepoint is invoke, and this points to the landing pad.
358   MachineBasicBlock *EHPad;
359   const TargetRegisterInfo &TRI;
360   const TargetInstrInfo &TII;
361   MachineFrameInfo &MFI;
362   // Mask with callee saved registers.
363   const uint32_t *Mask;
364   // Cache of frame indexes used on previous instruction processing.
365   FrameIndexesCache &CacheFI;
366   bool AllowGCPtrInCSR;
367   // Operands with physical registers requiring spilling.
368   SmallVector<unsigned, 8> OpsToSpill;
369   // Set of register to spill.
370   SmallVector<Register, 8> RegsToSpill;
371   // Set of registers to reload after statepoint.
372   SmallVector<Register, 8> RegsToReload;
373   // Map Register to Frame Slot index.
374   DenseMap<Register, int> RegToSlotIdx;
375 
376 public:
377   StatepointState(MachineInstr &MI, const uint32_t *Mask,
378                   FrameIndexesCache &CacheFI, bool AllowGCPtrInCSR)
379       : MI(MI), MF(*MI.getMF()), TRI(*MF.getSubtarget().getRegisterInfo()),
380         TII(*MF.getSubtarget().getInstrInfo()), MFI(MF.getFrameInfo()),
381         Mask(Mask), CacheFI(CacheFI), AllowGCPtrInCSR(AllowGCPtrInCSR) {
382 
383     // Find statepoint's landing pad, if any.
384     EHPad = nullptr;
385     MachineBasicBlock *MBB = MI.getParent();
386     // Invoke statepoint must be last one in block.
387     bool Last = std::none_of(++MI.getIterator(), MBB->end().getInstrIterator(),
388                              [](MachineInstr &I) {
389                                return I.getOpcode() == TargetOpcode::STATEPOINT;
390                              });
391 
392     if (!Last)
393       return;
394 
395     auto IsEHPad = [](MachineBasicBlock *B) { return B->isEHPad(); };
396 
397     assert(llvm::count_if(MBB->successors(), IsEHPad) < 2 && "multiple EHPads");
398 
399     auto It = llvm::find_if(MBB->successors(), IsEHPad);
400     if (It != MBB->succ_end())
401       EHPad = *It;
402   }
403 
404   MachineBasicBlock *getEHPad() const { return EHPad; }
405 
406   // Return true if register is callee saved.
407   bool isCalleeSaved(Register Reg) { return (Mask[Reg / 32] >> Reg % 32) & 1; }
408 
409   // Iterates over statepoint meta args to find caller saver registers.
410   // Also cache the size of found registers.
411   // Returns true if caller save registers found.
412   bool findRegistersToSpill() {
413     SmallSet<Register, 8> VisitedRegs;
414     SmallSet<Register, 8> GCRegs = collectGCRegs(MI);
415     for (unsigned Idx = StatepointOpers(&MI).getVarIdx(),
416                   EndIdx = MI.getNumOperands();
417          Idx < EndIdx; ++Idx) {
418       MachineOperand &MO = MI.getOperand(Idx);
419       if (!MO.isReg() || MO.isImplicit())
420         continue;
421       Register Reg = MO.getReg();
422       assert(Reg.isPhysical() && "Only physical regs are expected");
423 
424       if (isCalleeSaved(Reg) && (AllowGCPtrInCSR || !is_contained(GCRegs, Reg)))
425         continue;
426 
427       LLVM_DEBUG(dbgs() << "Will spill " << printReg(Reg, &TRI) << " at index "
428                         << Idx << "\n");
429 
430       if (VisitedRegs.insert(Reg).second)
431         RegsToSpill.push_back(Reg);
432       OpsToSpill.push_back(Idx);
433     }
434     CacheFI.sortRegisters(RegsToSpill);
435     return !RegsToSpill.empty();
436   }
437 
438   // Spill all caller saved registers right before statepoint instruction.
439   // Remember frame index where register is spilled.
440   void spillRegisters() {
441     for (Register Reg : RegsToSpill) {
442       int FI = CacheFI.getFrameIndex(Reg, EHPad);
443       const TargetRegisterClass *RC = TRI.getMinimalPhysRegClass(Reg);
444 
445       NumSpilledRegisters++;
446       RegToSlotIdx[Reg] = FI;
447 
448       LLVM_DEBUG(dbgs() << "Spilling " << printReg(Reg, &TRI) << " to FI " << FI
449                         << "\n");
450 
451       // Perform trivial copy propagation
452       bool IsKill = true;
453       MachineBasicBlock::iterator InsertBefore(MI);
454       Reg = performCopyPropagation(Reg, InsertBefore, IsKill, TII, TRI);
455 
456       LLVM_DEBUG(dbgs() << "Insert spill before " << *InsertBefore);
457       TII.storeRegToStackSlot(*MI.getParent(), InsertBefore, Reg, IsKill, FI,
458                               RC, &TRI);
459     }
460   }
461 
462   void insertReloadBefore(unsigned Reg, MachineBasicBlock::iterator It,
463                           MachineBasicBlock *MBB) {
464     const TargetRegisterClass *RC = TRI.getMinimalPhysRegClass(Reg);
465     int FI = RegToSlotIdx[Reg];
466     if (It != MBB->end()) {
467       TII.loadRegFromStackSlot(*MBB, It, Reg, FI, RC, &TRI);
468       return;
469     }
470 
471     // To insert reload at the end of MBB, insert it before last instruction
472     // and then swap them.
473     assert(MBB->begin() != MBB->end() && "Empty block");
474     --It;
475     TII.loadRegFromStackSlot(*MBB, It, Reg, FI, RC, &TRI);
476     MachineInstr *Reload = It->getPrevNode();
477     int Dummy = 0;
478     assert(TII.isLoadFromStackSlot(*Reload, Dummy) == Reg);
479     assert(Dummy == FI);
480     MBB->remove(Reload);
481     MBB->insertAfter(It, Reload);
482   }
483 
484   // Insert reloads of (relocated) registers spilled in statepoint.
485   void insertReloads(MachineInstr *NewStatepoint, RegReloadCache &RC) {
486     MachineBasicBlock *MBB = NewStatepoint->getParent();
487     auto InsertPoint = std::next(NewStatepoint->getIterator());
488 
489     for (auto Reg : RegsToReload) {
490       insertReloadBefore(Reg, InsertPoint, MBB);
491       LLVM_DEBUG(dbgs() << "Reloading " << printReg(Reg, &TRI) << " from FI "
492                         << RegToSlotIdx[Reg] << " after statepoint\n");
493 
494       if (EHPad && !RC.hasReload(Reg, RegToSlotIdx[Reg], EHPad)) {
495         RC.recordReload(Reg, RegToSlotIdx[Reg], EHPad);
496         auto EHPadInsertPoint = EHPad->SkipPHIsLabelsAndDebug(EHPad->begin());
497         insertReloadBefore(Reg, EHPadInsertPoint, EHPad);
498         LLVM_DEBUG(dbgs() << "...also reload at EHPad "
499                           << printMBBReference(*EHPad) << "\n");
500       }
501     }
502   }
503 
504   // Re-write statepoint machine instruction to replace caller saved operands
505   // with indirect memory location (frame index).
506   MachineInstr *rewriteStatepoint() {
507     MachineInstr *NewMI =
508         MF.CreateMachineInstr(TII.get(MI.getOpcode()), MI.getDebugLoc(), true);
509     MachineInstrBuilder MIB(MF, NewMI);
510 
511     unsigned NumOps = MI.getNumOperands();
512 
513     // New indices for the remaining defs.
514     SmallVector<unsigned, 8> NewIndices;
515     unsigned NumDefs = MI.getNumDefs();
516     for (unsigned I = 0; I < NumDefs; ++I) {
517       MachineOperand &DefMO = MI.getOperand(I);
518       assert(DefMO.isReg() && DefMO.isDef() && "Expected Reg Def operand");
519       Register Reg = DefMO.getReg();
520       if (!AllowGCPtrInCSR) {
521         assert(is_contained(RegsToSpill, Reg));
522         RegsToReload.push_back(Reg);
523       } else {
524         if (isCalleeSaved(Reg)) {
525           NewIndices.push_back(NewMI->getNumOperands());
526           MIB.addReg(Reg, RegState::Define);
527         } else {
528           NewIndices.push_back(NumOps);
529           RegsToReload.push_back(Reg);
530         }
531       }
532     }
533 
534     // Add End marker.
535     OpsToSpill.push_back(MI.getNumOperands());
536     unsigned CurOpIdx = 0;
537 
538     for (unsigned I = NumDefs; I < MI.getNumOperands(); ++I) {
539       MachineOperand &MO = MI.getOperand(I);
540       if (I == OpsToSpill[CurOpIdx]) {
541         int FI = RegToSlotIdx[MO.getReg()];
542         MIB.addImm(StackMaps::IndirectMemRefOp);
543         MIB.addImm(getRegisterSize(TRI, MO.getReg()));
544         assert(MO.isReg() && "Should be register");
545         assert(MO.getReg().isPhysical() && "Should be physical register");
546         MIB.addFrameIndex(FI);
547         MIB.addImm(0);
548         ++CurOpIdx;
549       } else {
550         MIB.add(MO);
551         unsigned OldDef;
552         if (AllowGCPtrInCSR && MI.isRegTiedToDefOperand(I, &OldDef)) {
553           assert(OldDef < NumDefs);
554           assert(NewIndices[OldDef] < NumOps);
555           MIB->tieOperands(NewIndices[OldDef], MIB->getNumOperands() - 1);
556         }
557       }
558     }
559     assert(CurOpIdx == (OpsToSpill.size() - 1) && "Not all operands processed");
560     // Add mem operands.
561     NewMI->setMemRefs(MF, MI.memoperands());
562     for (auto It : RegToSlotIdx) {
563       Register R = It.first;
564       int FrameIndex = It.second;
565       auto PtrInfo = MachinePointerInfo::getFixedStack(MF, FrameIndex);
566       MachineMemOperand::Flags Flags = MachineMemOperand::MOLoad;
567       if (is_contained(RegsToReload, R))
568         Flags |= MachineMemOperand::MOStore;
569       auto *MMO =
570           MF.getMachineMemOperand(PtrInfo, Flags, getRegisterSize(TRI, R),
571                                   MFI.getObjectAlign(FrameIndex));
572       NewMI->addMemOperand(MF, MMO);
573     }
574 
575     // Insert new statepoint and erase old one.
576     MI.getParent()->insert(MI, NewMI);
577 
578     LLVM_DEBUG(dbgs() << "rewritten statepoint to : " << *NewMI << "\n");
579     MI.eraseFromParent();
580     return NewMI;
581   }
582 };
583 
584 class StatepointProcessor {
585 private:
586   MachineFunction &MF;
587   const TargetRegisterInfo &TRI;
588   FrameIndexesCache CacheFI;
589   RegReloadCache ReloadCache;
590 
591 public:
592   StatepointProcessor(MachineFunction &MF)
593       : MF(MF), TRI(*MF.getSubtarget().getRegisterInfo()),
594         CacheFI(MF.getFrameInfo(), TRI) {}
595 
596   bool process(MachineInstr &MI, bool AllowGCPtrInCSR) {
597     StatepointOpers SO(&MI);
598     uint64_t Flags = SO.getFlags();
599     // Do nothing for LiveIn, it supports all registers.
600     if (Flags & (uint64_t)StatepointFlags::DeoptLiveIn)
601       return false;
602     LLVM_DEBUG(dbgs() << "\nMBB " << MI.getParent()->getNumber() << " "
603                       << MI.getParent()->getName() << " : process statepoint "
604                       << MI);
605     CallingConv::ID CC = SO.getCallingConv();
606     const uint32_t *Mask = TRI.getCallPreservedMask(MF, CC);
607     StatepointState SS(MI, Mask, CacheFI, AllowGCPtrInCSR);
608     CacheFI.reset(SS.getEHPad());
609 
610     if (!SS.findRegistersToSpill())
611       return false;
612 
613     SS.spillRegisters();
614     auto *NewStatepoint = SS.rewriteStatepoint();
615     SS.insertReloads(NewStatepoint, ReloadCache);
616     return true;
617   }
618 };
619 } // namespace
620 
621 bool FixupStatepointCallerSaved::runOnMachineFunction(MachineFunction &MF) {
622   if (skipFunction(MF.getFunction()))
623     return false;
624 
625   const Function &F = MF.getFunction();
626   if (!F.hasGC())
627     return false;
628 
629   SmallVector<MachineInstr *, 16> Statepoints;
630   for (MachineBasicBlock &BB : MF)
631     for (MachineInstr &I : BB)
632       if (I.getOpcode() == TargetOpcode::STATEPOINT)
633         Statepoints.push_back(&I);
634 
635   if (Statepoints.empty())
636     return false;
637 
638   bool Changed = false;
639   StatepointProcessor SPP(MF);
640   unsigned NumStatepoints = 0;
641   bool AllowGCPtrInCSR = PassGCPtrInCSR;
642   for (MachineInstr *I : Statepoints) {
643     ++NumStatepoints;
644     if (MaxStatepointsWithRegs.getNumOccurrences() &&
645         NumStatepoints >= MaxStatepointsWithRegs)
646       AllowGCPtrInCSR = false;
647     Changed |= SPP.process(*I, AllowGCPtrInCSR);
648   }
649   return Changed;
650 }
651