//===-------------- BPFMIPeephole.cpp - MI Peephole Cleanups  -------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This pass performs peephole optimizations to clean up ugly code sequences
// at the MachineInstruction layer.
//
// Currently, there are two optimizations implemented:
//  - One pre-RA MachineSSA pass to eliminate type promotion sequences, i.e.,
//    zero extensions of 32-bit subregisters to 64-bit registers, when the
//    compiler can prove that the subregister is defined by a 32-bit operation,
//    in which case the upper half of the underlying 64-bit register is already
//    zeroed implicitly.
//
//  - One post-RA PreEmit pass to do final cleanup on some redundant
//    instructions generated due to bad RA on subregisters.
//===----------------------------------------------------------------------===//

#include "BPF.h"
#include "BPFInstrInfo.h"
#include "BPFTargetMachine.h"
#include "llvm/ADT/Statistic.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/CodeGen/LivePhysRegs.h"
#include "llvm/CodeGen/MachineFrameInfo.h"
#include "llvm/CodeGen/MachineFunctionPass.h"
#include "llvm/CodeGen/MachineInstrBuilder.h"
#include "llvm/CodeGen/MachineRegisterInfo.h"
#include "llvm/Support/Debug.h"
#include <set>

using namespace llvm;

#define DEBUG_TYPE "bpf-mi-zext-elim"

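// Conservative cut-off used by in16BitRange()/adjustBranch() below when
// deciding whether a branch target still fits in the 16-bit offset field.
// It defaults to INT16_MAX >> 1 because branch relaxation itself inserts new
// insns, which can push nearby offsets further out.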
static cl::opt<int> GotolAbsLowBound("gotol-abs-low-bound", cl::Hidden,
  cl::init(INT16_MAX >> 1), cl::desc("Specify gotol lower bound"));

STATISTIC(ZExtElemNum, "Number of zero extension shifts eliminated");

namespace {

struct BPFMIPeephole : public MachineFunctionPass {

  static char ID;
  const BPFInstrInfo *TII;
  MachineFunction *MF;
  MachineRegisterInfo *MRI;

  BPFMIPeephole() : MachineFunctionPass(ID) {
    initializeBPFMIPeepholePass(*PassRegistry::getPassRegistry());
  }

private:
  // Initialize class variables.
  void initialize(MachineFunction &MFParm);

  bool isCopyFrom32Def(MachineInstr *CopyMI);
  bool isInsnFrom32Def(MachineInstr *DefInsn);
  bool isPhiFrom32Def(MachineInstr *MovMI);
  bool isMovFrom32Def(MachineInstr *MovMI);
  bool eliminateZExtSeq();
  bool eliminateZExt();

  std::set<MachineInstr *> PhiInsns;

public:

  // Main entry point for this pass.
  bool runOnMachineFunction(MachineFunction &MF) override {
    if (skipFunction(MF.getFunction()))
      return false;

    initialize(MF);

    // First try to eliminate (zext, lshift, rshift) and then
    // try to eliminate zext.
    bool ZExtSeqExist, ZExtExist;
    ZExtSeqExist = eliminateZExtSeq();
    ZExtExist = eliminateZExt();
    return ZExtSeqExist || ZExtExist;
  }
};

// Initialize class variables.
void BPFMIPeephole::initialize(MachineFunction &MFParm) {
  MF = &MFParm;
  MRI = &MF->getRegInfo();
  TII = MF->getSubtarget<BPFSubtarget>().getInstrInfo();
  LLVM_DEBUG(dbgs() << "*** BPF MachineSSA ZEXT Elim peephole pass ***\n\n");
}

bool BPFMIPeephole::isCopyFrom32Def(MachineInstr *CopyMI)
{
  MachineOperand &opnd = CopyMI->getOperand(1);

  if (!opnd.isReg())
    return false;

  // Return false if the value comes from a 32-bit physical register.
  // Most likely, this physical register is aliased to a function call
  // return value or to the current function's parameters.
  Register Reg = opnd.getReg();
  if (!Reg.isVirtual())
    return false;

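  // A copy from a 64-bit GPR tells us nothing about its upper 32 bits, so
  // only 32-bit (GPR32) sources qualify.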
  if (MRI->getRegClass(Reg) == &BPF::GPRRegClass)
    return false;

  MachineInstr *DefInsn = MRI->getVRegDef(Reg);
  if (!isInsnFrom32Def(DefInsn))
    return false;

  return true;
}

bool BPFMIPeephole::isPhiFrom32Def(MachineInstr *PhiMI)
{
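  // PHI operands come in (value, predecessor block) pairs after the result
  // operand, so step by two and inspect only the incoming values.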
  for (unsigned i = 1, e = PhiMI->getNumOperands(); i < e; i += 2) {
    MachineOperand &opnd = PhiMI->getOperand(i);

    if (!opnd.isReg())
      return false;

    MachineInstr *PhiDef = MRI->getVRegDef(opnd.getReg());
    if (!PhiDef)
      return false;
    if (PhiDef->isPHI()) {
      if (!PhiInsns.insert(PhiDef).second)
        return false;
      if (!isPhiFrom32Def(PhiDef))
        return false;
    }
    if (PhiDef->getOpcode() == BPF::COPY && !isCopyFrom32Def(PhiDef))
      return false;
  }

  return true;
}

// Check whether the virtual register defined by \p DefInsn can be traced
// back to 32-bit definitions, looking through PHI and COPY instructions.
bool BPFMIPeephole::isInsnFrom32Def(MachineInstr *DefInsn)
{
  if (!DefInsn)
    return false;

  if (DefInsn->isPHI()) {
    if (!PhiInsns.insert(DefInsn).second)
      return false;
    if (!isPhiFrom32Def(DefInsn))
      return false;
  } else if (DefInsn->getOpcode() == BPF::COPY) {
    if (!isCopyFrom32Def(DefInsn))
      return false;
  }

  return true;
}

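// Check whether the source of \p MovMI is known to come from a 32-bit
// definition, i.e. the upper 32 bits of the underlying 64-bit register are
// already zero.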
bool BPFMIPeephole::isMovFrom32Def(MachineInstr *MovMI)
{
  MachineInstr *DefInsn = MRI->getVRegDef(MovMI->getOperand(1).getReg());

  LLVM_DEBUG(dbgs() << "  Def of Mov Src:");
  LLVM_DEBUG(DefInsn->dump());

  PhiInsns.clear();
  if (!isInsnFrom32Def(DefInsn))
    return false;

  LLVM_DEBUG(dbgs() << "  One ZExt elim sequence identified.\n");

  return true;
}

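// Eliminate the (MOV_32_64, SLL_ri 32, SRL_ri 32) zero-extension idiom when
// the 32-bit source is known to be produced by a 32-bit operation that has
// already cleared the upper half of the 64-bit register. The whole sequence
// is replaced by a single SUBREG_TO_REG.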
bool BPFMIPeephole::eliminateZExtSeq() {
  MachineInstr* ToErase = nullptr;
  bool Eliminated = false;

  for (MachineBasicBlock &MBB : *MF) {
    for (MachineInstr &MI : MBB) {
      // If the previous instruction was marked for elimination, remove it now.
      if (ToErase) {
        ToErase->eraseFromParent();
        ToErase = nullptr;
      }

      // Eliminate the 32-bit to 64-bit zero extension sequence when possible.
      //
      //   MOV_32_64 rB, wA
      //   SLL_ri    rB, rB, 32
      //   SRL_ri    rB, rB, 32
      if (MI.getOpcode() == BPF::SRL_ri &&
          MI.getOperand(2).getImm() == 32) {
        Register DstReg = MI.getOperand(0).getReg();
        Register ShfReg = MI.getOperand(1).getReg();
        MachineInstr *SllMI = MRI->getVRegDef(ShfReg);

        LLVM_DEBUG(dbgs() << "Starting SRL found:");
        LLVM_DEBUG(MI.dump());

        if (!SllMI ||
            SllMI->isPHI() ||
            SllMI->getOpcode() != BPF::SLL_ri ||
            SllMI->getOperand(2).getImm() != 32)
          continue;

        LLVM_DEBUG(dbgs() << "  SLL found:");
        LLVM_DEBUG(SllMI->dump());

        MachineInstr *MovMI = MRI->getVRegDef(SllMI->getOperand(1).getReg());
        if (!MovMI ||
            MovMI->isPHI() ||
            MovMI->getOpcode() != BPF::MOV_32_64)
          continue;

        LLVM_DEBUG(dbgs() << "  Type cast Mov found:");
        LLVM_DEBUG(MovMI->dump());

        Register SubReg = MovMI->getOperand(1).getReg();
        if (!isMovFrom32Def(MovMI)) {
          LLVM_DEBUG(dbgs()
                     << "  One ZExt elim sequence failed qualifying elim.\n");
          continue;
        }

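        // Re-express the value with SUBREG_TO_REG, a pseudo that emits no
        // code and just records that the upper 32 bits of DstReg are zero.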
        BuildMI(MBB, MI, MI.getDebugLoc(), TII->get(BPF::SUBREG_TO_REG), DstReg)
          .addImm(0).addReg(SubReg).addImm(BPF::sub_32);

        SllMI->eraseFromParent();
        MovMI->eraseFromParent();
        // MI is the right shift; we can't erase it in its own iteration.
        // Record it in ToErase and erase it in the next iteration.
        ToErase = &MI;
        ZExtElemNum++;
        Eliminated = true;
      }
    }
  }

  return Eliminated;
}

bool BPFMIPeephole::eliminateZExt() {
  MachineInstr* ToErase = nullptr;
  bool Eliminated = false;

  for (MachineBasicBlock &MBB : *MF) {
    for (MachineInstr &MI : MBB) {
      // If the previous instruction was marked for elimination, remove it now.
      if (ToErase) {
        ToErase->eraseFromParent();
        ToErase = nullptr;
      }

      if (MI.getOpcode() != BPF::MOV_32_64)
        continue;

      // Eliminate MOV_32_64 if possible.
      //   MOV_32_64 rA, wB
      //
      // If wB has been zero extended, replace the MOV_32_64 with a
      // SUBREG_TO_REG. This works around BPF programs where pkt->{data,
      // data_end} are declared as u32 but the verifier actually populates
      // them as 64-bit pointers; the MOV_32_64 would zero out the top 32
      // bits.
      LLVM_DEBUG(dbgs() << "Candidate MOV_32_64 instruction:");
      LLVM_DEBUG(MI.dump());

      if (!isMovFrom32Def(&MI))
        continue;

      LLVM_DEBUG(dbgs() << "Removing the MOV_32_64 instruction\n");

      Register dst = MI.getOperand(0).getReg();
      Register src = MI.getOperand(1).getReg();

      // Build a SUBREG_TO_REG instruction.
      BuildMI(MBB, MI, MI.getDebugLoc(), TII->get(BPF::SUBREG_TO_REG), dst)
        .addImm(0).addReg(src).addImm(BPF::sub_32);

      ToErase = &MI;
      Eliminated = true;
    }
  }

  return Eliminated;
}

} // end anonymous namespace

INITIALIZE_PASS(BPFMIPeephole, DEBUG_TYPE,
                "BPF MachineSSA Peephole Optimization For ZEXT Eliminate",
                false, false)

char BPFMIPeephole::ID = 0;
FunctionPass* llvm::createBPFMIPeepholePass() { return new BPFMIPeephole(); }

STATISTIC(RedundantMovElemNum, "Number of redundant moves eliminated");

namespace {

struct BPFMIPreEmitPeephole : public MachineFunctionPass {

  static char ID;
  MachineFunction *MF;
  const TargetRegisterInfo *TRI;
  const BPFInstrInfo *TII;
  bool SupportGotol;

  BPFMIPreEmitPeephole() : MachineFunctionPass(ID) {
    initializeBPFMIPreEmitPeepholePass(*PassRegistry::getPassRegistry());
  }

private:
  // Initialize class variables.
  void initialize(MachineFunction &MFParm);

  bool in16BitRange(int Num);
  bool eliminateRedundantMov();
  bool adjustBranch();
  bool insertMissingCallerSavedSpills();
  bool removeMayGotoZero();

public:

  // Main entry point for this pass.
  bool runOnMachineFunction(MachineFunction &MF) override {
    if (skipFunction(MF.getFunction()))
      return false;

    initialize(MF);

    bool Changed;
    Changed = eliminateRedundantMov();
    if (SupportGotol)
      Changed = adjustBranch() || Changed;
    Changed |= insertMissingCallerSavedSpills();
    Changed |= removeMayGotoZero();
    return Changed;
  }
};

// Initialize class variables.
void BPFMIPreEmitPeephole::initialize(MachineFunction &MFParm) {
  MF = &MFParm;
  TII = MF->getSubtarget<BPFSubtarget>().getInstrInfo();
  TRI = MF->getSubtarget<BPFSubtarget>().getRegisterInfo();
  SupportGotol = MF->getSubtarget<BPFSubtarget>().hasGotol();
  LLVM_DEBUG(dbgs() << "*** BPF PreEmit peephole pass ***\n\n");
}

bool BPFMIPreEmitPeephole::eliminateRedundantMov() {
  MachineInstr* ToErase = nullptr;
  bool Eliminated = false;

  for (MachineBasicBlock &MBB : *MF) {
    for (MachineInstr &MI : MBB) {
      // If the previous instruction was marked for elimination, remove it now.
      if (ToErase) {
        LLVM_DEBUG(dbgs() << "  Redundant Mov Eliminated:");
        LLVM_DEBUG(ToErase->dump());
        ToErase->eraseFromParent();
        ToErase = nullptr;
      }

      // Eliminate identical move:
      //
      //   MOV rA, rA
      //
      // Note that we cannot remove
      //   MOV_32_64  rA, wA
      //   MOV_rr_32  wA, wA
      // as these two instructions have a side effect: they zero out the
      // top 32 bits of rA.
      unsigned Opcode = MI.getOpcode();
      if (Opcode == BPF::MOV_rr) {
        Register dst = MI.getOperand(0).getReg();
        Register src = MI.getOperand(1).getReg();

        if (dst != src)
          continue;

        ToErase = &MI;
        RedundantMovElemNum++;
        Eliminated = true;
      }
    }
  }

  return Eliminated;
}

bool BPFMIPreEmitPeephole::in16BitRange(int Num) {
  // The cut-off is not precisely the 16-bit range since new insns are added
  // during the transformation, so be a little conservative.
  return Num >= -GotolAbsLowBound && Num <= GotolAbsLowBound;
}

// Before cpu=v4, only a 16-bit branch target offset (-0x8000 to 0x7fff)
// is supported for both unconditional (JMP) and conditional (JEQ, JSGT,
// etc.) branches. In certain cases, e.g., with full unrolling, the branch
// target offset might exceed the 16-bit range, and if this happens LLVM
// will generate incorrect code as the offset is truncated to 16 bits.
//
// To fix this rare case, a new insn JMPL is introduced. This new insn
// supports a 32-bit branch target offset. The compiler does not use this
// insn during instruction selection. Rather, the BPF backend estimates
// branch target offsets and does the JMP -> JMPL and JEQ -> JEQ + JMPL
// conversion if the estimated branch target offset is beyond the 16-bit
// range.
bool BPFMIPreEmitPeephole::adjustBranch() {
  bool Changed = false;
  int CurrNumInsns = 0;
  DenseMap<MachineBasicBlock *, int> SoFarNumInsns;
  DenseMap<MachineBasicBlock *, MachineBasicBlock *> FollowThroughBB;
  std::vector<MachineBasicBlock *> MBBs;

  MachineBasicBlock *PrevBB = nullptr;
  for (MachineBasicBlock &MBB : *MF) {
    // MBB.size() is the number of insns in this basic block, including some
    // debug info, e.g., DEBUG_VALUE, so we may over-count a little bit.
    // Typically we have way more normal insns than DEBUG_VALUE insns.
    // Also, if we indeed need to convert a conditional branch like JEQ to
    // JEQ + JMPL, we actually introduce some new insns, as shown below.
    CurrNumInsns += (int)MBB.size();
    SoFarNumInsns[&MBB] = CurrNumInsns;
    if (PrevBB != nullptr)
      FollowThroughBB[PrevBB] = &MBB;
    PrevBB = &MBB;
    // A list of original BBs to make later traversal easier.
    MBBs.push_back(&MBB);
  }
  FollowThroughBB[PrevBB] = nullptr;

  for (unsigned i = 0; i < MBBs.size(); i++) {
    // We have four cases here:
    //  (1). no terminator, simple follow through.
    //  (2). jmp to another bb.
    //  (3). conditional jmp to another bb or follow through.
    //  (4). conditional jmp followed by an unconditional jmp.
    MachineInstr *CondJmp = nullptr, *UncondJmp = nullptr;

    MachineBasicBlock *MBB = MBBs[i];
    for (MachineInstr &Term : MBB->terminators()) {
      if (Term.isConditionalBranch()) {
        assert(CondJmp == nullptr);
        CondJmp = &Term;
      } else if (Term.isUnconditionalBranch()) {
        assert(UncondJmp == nullptr);
        UncondJmp = &Term;
      }
    }

    // (1). no terminator, simple follow through.
    if (!CondJmp && !UncondJmp)
      continue;

    MachineBasicBlock *CondTargetBB, *JmpBB;
    CurrNumInsns = SoFarNumInsns[MBB];

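    // SoFarNumInsns[BB] counts all insns up to and including BB, so
    // SoFarNumInsns[Target] - Target->size() approximates the index of
    // Target's first insn; subtracting CurrNumInsns gives a rough branch
    // offset in units of insns.
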
    // (2). jmp to another bb.
    if (!CondJmp && UncondJmp) {
      JmpBB = UncondJmp->getOperand(0).getMBB();
      if (in16BitRange(SoFarNumInsns[JmpBB] - JmpBB->size() - CurrNumInsns))
        continue;

      // Replace this insn with a JMPL.
      BuildMI(MBB, UncondJmp->getDebugLoc(), TII->get(BPF::JMPL)).addMBB(JmpBB);
      UncondJmp->eraseFromParent();
      Changed = true;
      continue;
    }

    const BasicBlock *TermBB = MBB->getBasicBlock();
    int Dist;

    // (3). conditional jmp to another bb or follow through.
    if (!UncondJmp) {
      CondTargetBB = CondJmp->getOperand(2).getMBB();
      MachineBasicBlock *FollowBB = FollowThroughBB[MBB];
      Dist = SoFarNumInsns[CondTargetBB] - CondTargetBB->size() - CurrNumInsns;
      if (in16BitRange(Dist))
        continue;

      // We have
      //   B2: ...
      //       if (cond) goto B5
      //   B3: ...
      // where B2 -> B5 is beyond the 16-bit range.
      //
      // We do not have a 32-bit cond jmp insn, so we do the following:
      //   B2:     ...
      //           if (cond) goto New_B1
      //   New_B0: goto B3
      //   New_B1: gotol B5
      //   B3: ...
      // Basically two new basic blocks are created.
      MachineBasicBlock *New_B0 = MF->CreateMachineBasicBlock(TermBB);
      MachineBasicBlock *New_B1 = MF->CreateMachineBasicBlock(TermBB);

      // Insert New_B0 and New_B1 into function block list.
      MachineFunction::iterator MBB_I  = ++MBB->getIterator();
      MF->insert(MBB_I, New_B0);
      MF->insert(MBB_I, New_B1);

      // Replace the B2 conditional jump so that it targets New_B1.
      if (CondJmp->getOperand(1).isReg())
        BuildMI(*MBB, MachineBasicBlock::iterator(*CondJmp), CondJmp->getDebugLoc(), TII->get(CondJmp->getOpcode()))
            .addReg(CondJmp->getOperand(0).getReg())
            .addReg(CondJmp->getOperand(1).getReg())
            .addMBB(New_B1);
      else
        BuildMI(*MBB, MachineBasicBlock::iterator(*CondJmp), CondJmp->getDebugLoc(), TII->get(CondJmp->getOpcode()))
            .addReg(CondJmp->getOperand(0).getReg())
            .addImm(CondJmp->getOperand(1).getImm())
            .addMBB(New_B1);

      // It is possible that CondTargetBB and FollowBB are the same, but the
      // Dist check above should have already filtered out this case.
      MBB->removeSuccessor(CondTargetBB);
      MBB->removeSuccessor(FollowBB);
      MBB->addSuccessor(New_B0);
      MBB->addSuccessor(New_B1);

      // Populate insns in New_B0 and New_B1.
      BuildMI(New_B0, CondJmp->getDebugLoc(), TII->get(BPF::JMP)).addMBB(FollowBB);
      BuildMI(New_B1, CondJmp->getDebugLoc(), TII->get(BPF::JMPL))
          .addMBB(CondTargetBB);

      New_B0->addSuccessor(FollowBB);
      New_B1->addSuccessor(CondTargetBB);
      CondJmp->eraseFromParent();
      Changed = true;
      continue;
    }

    //  (4). conditional jmp followed by an unconditional jmp.
    CondTargetBB = CondJmp->getOperand(2).getMBB();
    JmpBB = UncondJmp->getOperand(0).getMBB();

    // We have
    //   B2: ...
    //       if (cond) goto B5
    //       JMP B7
    //   B3: ...
    //
    // If only B2->B5 is out of 16bit range, we can do
    //   B2: ...
    //       if (cond) goto new_B
    //       JMP B7
    //   New_B: gotol B5
    //   B3: ...
    //
    // If only 'JMP B7' is out of 16bit range, we can replace
    // 'JMP B7' with 'JMPL B7'.
    //
    // If both B2->B5 and 'JMP B7' are out of range, just do
    // both of the above transformations.
    Dist = SoFarNumInsns[CondTargetBB] - CondTargetBB->size() - CurrNumInsns;
    if (!in16BitRange(Dist)) {
      MachineBasicBlock *New_B = MF->CreateMachineBasicBlock(TermBB);

      // Insert New_B into the function block list.
      MF->insert(++MBB->getIterator(), New_B);

      // Replace the B2 conditional jump so that it targets New_B.
      if (CondJmp->getOperand(1).isReg())
        BuildMI(*MBB, MachineBasicBlock::iterator(*CondJmp), CondJmp->getDebugLoc(), TII->get(CondJmp->getOpcode()))
            .addReg(CondJmp->getOperand(0).getReg())
            .addReg(CondJmp->getOperand(1).getReg())
            .addMBB(New_B);
      else
        BuildMI(*MBB, MachineBasicBlock::iterator(*CondJmp), CondJmp->getDebugLoc(), TII->get(CondJmp->getOpcode()))
            .addReg(CondJmp->getOperand(0).getReg())
            .addImm(CondJmp->getOperand(1).getImm())
            .addMBB(New_B);

      if (CondTargetBB != JmpBB)
        MBB->removeSuccessor(CondTargetBB);
      MBB->addSuccessor(New_B);

      // Populate insn in New_B.
      BuildMI(New_B, CondJmp->getDebugLoc(), TII->get(BPF::JMPL)).addMBB(CondTargetBB);

      New_B->addSuccessor(CondTargetBB);
      CondJmp->eraseFromParent();
      Changed = true;
    }

    if (!in16BitRange(SoFarNumInsns[JmpBB] - CurrNumInsns)) {
      BuildMI(MBB, UncondJmp->getDebugLoc(), TII->get(BPF::JMPL)).addMBB(JmpBB);
      UncondJmp->eraseFromParent();
      Changed = true;
    }
  }

  return Changed;
}

static const unsigned CallerSavedRegs[] = {BPF::R0, BPF::R1, BPF::R2,
                                           BPF::R3, BPF::R4, BPF::R5};

struct BPFFastCall {
  MachineInstr *MI;
  unsigned LiveCallerSavedRegs;
};

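// Walk \p BB bottom-up with physical register liveness tracking and record
// every call together with a bitmask of caller-saved registers (R0-R5) that
// are live across it but not redefined by the call itself. These are the
// registers that insertMissingCallerSavedSpills() must preserve around the
// call.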
static void collectBPFFastCalls(const TargetRegisterInfo *TRI,
                                LivePhysRegs &LiveRegs, MachineBasicBlock &BB,
                                SmallVectorImpl<BPFFastCall> &Calls) {
  LiveRegs.init(*TRI);
  LiveRegs.addLiveOuts(BB);
  Calls.clear();
  for (MachineInstr &MI : llvm::reverse(BB)) {
    if (MI.isCall()) {
      unsigned LiveCallerSavedRegs = 0;
      for (MCRegister R : CallerSavedRegs) {
        bool DoSpillFill = false;
        for (MCPhysReg SR : TRI->subregs(R))
          DoSpillFill |= !MI.definesRegister(SR, TRI) && LiveRegs.contains(SR);
        if (!DoSpillFill)
          continue;
        LiveCallerSavedRegs |= 1 << R;
      }
      if (LiveCallerSavedRegs)
        Calls.push_back({&MI, LiveCallerSavedRegs});
    }
    LiveRegs.stepBackward(MI);
  }
}

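// Find the lowest offset of any live fixed stack object and round it down to
// a multiple of \p SlotSize, so that new spill slots can be placed below all
// existing fixed objects.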
static int64_t computeMinFixedObjOffset(MachineFrameInfo &MFI,
                                        unsigned SlotSize) {
  int64_t MinFixedObjOffset = 0;
  // Same logic as in X86FrameLowering::adjustFrameForMsvcCxxEh()
  for (int I = MFI.getObjectIndexBegin(); I < MFI.getObjectIndexEnd(); ++I) {
    if (MFI.isDeadObjectIndex(I))
      continue;
    MinFixedObjOffset = std::min(MinFixedObjOffset, MFI.getObjectOffset(I));
  }
  MinFixedObjOffset -=
      (SlotSize + MinFixedObjOffset % SlotSize) & (SlotSize - 1);
  return MinFixedObjOffset;
}

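// For each call that has caller-saved registers live across it (see
// collectBPFFastCalls above), allocate fixed spill slots below the existing
// frame objects and bracket the call with STD/LDD pairs that save and
// restore those registers.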
bool BPFMIPreEmitPeephole::insertMissingCallerSavedSpills() {
  MachineFrameInfo &MFI = MF->getFrameInfo();
  SmallVector<BPFFastCall, 8> Calls;
  LivePhysRegs LiveRegs;
  const unsigned SlotSize = 8;
  int64_t MinFixedObjOffset = computeMinFixedObjOffset(MFI, SlotSize);
  bool Changed = false;
  for (MachineBasicBlock &BB : *MF) {
    collectBPFFastCalls(TRI, LiveRegs, BB, Calls);
    Changed |= !Calls.empty();
    for (BPFFastCall &Call : Calls) {
      int64_t CurOffset = MinFixedObjOffset;
      for (MCRegister Reg : CallerSavedRegs) {
        if (((1 << Reg) & Call.LiveCallerSavedRegs) == 0)
          continue;
        // Allocate stack object
        CurOffset -= SlotSize;
        MFI.CreateFixedSpillStackObject(SlotSize, CurOffset);
        // Generate spill
        BuildMI(BB, Call.MI->getIterator(), Call.MI->getDebugLoc(),
                TII->get(BPF::STD))
            .addReg(Reg, RegState::Kill)
            .addReg(BPF::R10)
            .addImm(CurOffset);
        // Generate fill
        BuildMI(BB, ++Call.MI->getIterator(), Call.MI->getDebugLoc(),
                TII->get(BPF::LDD))
            .addReg(Reg, RegState::Define)
            .addReg(BPF::R10)
            .addImm(CurOffset);
      }
    }
  }
  return Changed;
}

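// An inline-asm 'may_goto <label>' whose target is the immediately following
// basic block is a branch with offset zero: whether it is taken or not,
// execution continues at the same place. Remove such insns, and remove the
// whole block when the may_goto was its only insn.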
bool BPFMIPreEmitPeephole::removeMayGotoZero() {
  bool Changed = false;
  MachineBasicBlock *Prev_MBB, *Curr_MBB = nullptr;

  for (MachineBasicBlock &MBB : make_early_inc_range(reverse(*MF))) {
    Prev_MBB = Curr_MBB;
    Curr_MBB = &MBB;
    if (Prev_MBB == nullptr || Curr_MBB->empty())
      continue;

    MachineInstr &MI = Curr_MBB->back();
    if (MI.getOpcode() != TargetOpcode::INLINEASM_BR)
      continue;

    const char *AsmStr = MI.getOperand(0).getSymbolName();
    SmallVector<StringRef, 4> AsmPieces;
    SplitString(AsmStr, AsmPieces, ";\n");

    // Multiple insns in one inline asm statement are not supported.
    if (AsmPieces.size() != 1)
      continue;

    // The asm insn must be a may_goto insn.
    SmallVector<StringRef, 4> AsmOpPieces;
    SplitString(AsmPieces[0], AsmOpPieces, " ");
    if (AsmOpPieces.size() != 2 || AsmOpPieces[0] != "may_goto")
      continue;
    // Enforce the format of 'may_goto <label>'.
    if (AsmOpPieces[1] != "${0:l}" && AsmOpPieces[1] != "$0")
      continue;

    // Get the may_goto branch target.
    MachineOperand &MO = MI.getOperand(InlineAsm::MIOp_FirstOperand + 1);
    if (!MO.isMBB() || MO.getMBB() != Prev_MBB)
      continue;

    Changed = true;
    if (Curr_MBB->begin() == MI) {
      // The 'may_goto' insn is the only insn in this basic block.
      Curr_MBB->removeSuccessor(Prev_MBB);
      for (MachineBasicBlock *Pred : Curr_MBB->predecessors())
        Pred->replaceSuccessor(Curr_MBB, Prev_MBB);
      Curr_MBB->eraseFromParent();
      Curr_MBB = Prev_MBB;
    } else {
      // Remove 'may_goto' insn.
      MI.eraseFromParent();
    }
  }

  return Changed;
}

} // end anonymous namespace

INITIALIZE_PASS(BPFMIPreEmitPeephole, "bpf-mi-pemit-peephole",
                "BPF PreEmit Peephole Optimization", false, false)

char BPFMIPreEmitPeephole::ID = 0;
FunctionPass* llvm::createBPFMIPreEmitPeepholePass()
{
  return new BPFMIPreEmitPeephole();
}