xref: /llvm-project/bolt/lib/Target/RISCV/RISCVMCPlusBuilder.cpp (revision 3023b15fb1ec00dbe6a1cb630236125f500978ef)
1 //===- bolt/Target/RISCV/RISCVMCPlusBuilder.cpp -----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file provides RISCV-specific MCPlus builder.
10 //
11 //===----------------------------------------------------------------------===//
12 
#include "MCTargetDesc/RISCVMCExpr.h"
#include "MCTargetDesc/RISCVMCTargetDesc.h"
#include "bolt/Core/MCPlusBuilder.h"
#include "llvm/BinaryFormat/ELF.h"
#include "llvm/MC/MCInst.h"
#include "llvm/MC/MCSubtargetInfo.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/MathExtras.h"
20 
21 #define DEBUG_TYPE "mcplus"
22 
23 using namespace llvm;
24 using namespace bolt;
25 
26 namespace {
27 
28 class RISCVMCPlusBuilder : public MCPlusBuilder {
29 public:
30   using MCPlusBuilder::MCPlusBuilder;
31 
32   bool equals(const MCTargetExpr &A, const MCTargetExpr &B,
33               CompFuncTy Comp) const override {
34     const auto &RISCVExprA = cast<RISCVMCExpr>(A);
35     const auto &RISCVExprB = cast<RISCVMCExpr>(B);
36     if (RISCVExprA.getKind() != RISCVExprB.getKind())
37       return false;
38 
39     return MCPlusBuilder::equals(*RISCVExprA.getSubExpr(),
40                                  *RISCVExprB.getSubExpr(), Comp);
41   }
42 
43   void getCalleeSavedRegs(BitVector &Regs) const override {
44     Regs |= getAliases(RISCV::X2);
45     Regs |= getAliases(RISCV::X8);
46     Regs |= getAliases(RISCV::X9);
47     Regs |= getAliases(RISCV::X18);
48     Regs |= getAliases(RISCV::X19);
49     Regs |= getAliases(RISCV::X20);
50     Regs |= getAliases(RISCV::X21);
51     Regs |= getAliases(RISCV::X22);
52     Regs |= getAliases(RISCV::X23);
53     Regs |= getAliases(RISCV::X24);
54     Regs |= getAliases(RISCV::X25);
55     Regs |= getAliases(RISCV::X26);
56     Regs |= getAliases(RISCV::X27);
57   }
58 
59   bool shouldRecordCodeRelocation(uint64_t RelType) const override {
60     switch (RelType) {
61     case ELF::R_RISCV_JAL:
62     case ELF::R_RISCV_CALL:
63     case ELF::R_RISCV_CALL_PLT:
64     case ELF::R_RISCV_BRANCH:
65     case ELF::R_RISCV_RVC_BRANCH:
66     case ELF::R_RISCV_RVC_JUMP:
67     case ELF::R_RISCV_GOT_HI20:
68     case ELF::R_RISCV_PCREL_HI20:
69     case ELF::R_RISCV_PCREL_LO12_I:
70     case ELF::R_RISCV_PCREL_LO12_S:
71     case ELF::R_RISCV_HI20:
72     case ELF::R_RISCV_LO12_I:
73     case ELF::R_RISCV_LO12_S:
74     case ELF::R_RISCV_TLS_GOT_HI20:
75       return true;
76     default:
77       llvm_unreachable("Unexpected RISCV relocation type in code");
78     }
79   }
80 
81   bool isNop(const MCInst &Inst) const {
82     return Inst.getOpcode() == RISCV::ADDI &&
83            Inst.getOperand(0).getReg() == RISCV::X0 &&
84            Inst.getOperand(1).getReg() == RISCV::X0 &&
85            Inst.getOperand(2).getImm() == 0;
86   }
87 
88   bool isCNop(const MCInst &Inst) const {
89     return Inst.getOpcode() == RISCV::C_NOP;
90   }
91 
92   bool isNoop(const MCInst &Inst) const override {
93     return isNop(Inst) || isCNop(Inst);
94   }
95 
96   bool isPseudo(const MCInst &Inst) const override {
97     switch (Inst.getOpcode()) {
98     default:
99       return MCPlusBuilder::isPseudo(Inst);
100     case RISCV::PseudoCALL:
101     case RISCV::PseudoTAIL:
102       return false;
103     }
104   }
105 
106   bool isIndirectCall(const MCInst &Inst) const override {
107     if (!isCall(Inst))
108       return false;
109 
110     switch (Inst.getOpcode()) {
111     default:
112       return false;
113     case RISCV::JALR:
114     case RISCV::C_JALR:
115     case RISCV::C_JR:
116       return true;
117     }
118   }
119 
120   bool hasPCRelOperand(const MCInst &Inst) const override {
121     switch (Inst.getOpcode()) {
122     default:
123       return false;
124     case RISCV::JAL:
125     case RISCV::AUIPC:
126       return true;
127     }
128   }
129 
130   unsigned getInvertedBranchOpcode(unsigned Opcode) const {
131     switch (Opcode) {
132     default:
133       llvm_unreachable("Failed to invert branch opcode");
134       return Opcode;
135     case RISCV::BEQ:
136       return RISCV::BNE;
137     case RISCV::BNE:
138       return RISCV::BEQ;
139     case RISCV::BLT:
140       return RISCV::BGE;
141     case RISCV::BGE:
142       return RISCV::BLT;
143     case RISCV::BLTU:
144       return RISCV::BGEU;
145     case RISCV::BGEU:
146       return RISCV::BLTU;
147     case RISCV::C_BEQZ:
148       return RISCV::C_BNEZ;
149     case RISCV::C_BNEZ:
150       return RISCV::C_BEQZ;
151     }
152   }
153 
154   void reverseBranchCondition(MCInst &Inst, const MCSymbol *TBB,
155                               MCContext *Ctx) const override {
156     auto Opcode = getInvertedBranchOpcode(Inst.getOpcode());
157     Inst.setOpcode(Opcode);
158     replaceBranchTarget(Inst, TBB, Ctx);
159   }
160 
161   void replaceBranchTarget(MCInst &Inst, const MCSymbol *TBB,
162                            MCContext *Ctx) const override {
163     assert((isCall(Inst) || isBranch(Inst)) && !isIndirectBranch(Inst) &&
164            "Invalid instruction");
165 
166     unsigned SymOpIndex;
167     auto Result = getSymbolRefOperandNum(Inst, SymOpIndex);
168     (void)Result;
169     assert(Result && "unimplemented branch");
170 
171     Inst.getOperand(SymOpIndex) = MCOperand::createExpr(
172         MCSymbolRefExpr::create(TBB, MCSymbolRefExpr::VK_None, *Ctx));
173   }
174 
175   IndirectBranchType analyzeIndirectBranch(
176       MCInst &Instruction, InstructionIterator Begin, InstructionIterator End,
177       const unsigned PtrSize, MCInst *&MemLocInstr, unsigned &BaseRegNum,
178       unsigned &IndexRegNum, int64_t &DispValue, const MCExpr *&DispExpr,
179       MCInst *&PCRelBaseOut, MCInst *&FixedEntryLoadInst) const override {
180     MemLocInstr = nullptr;
181     BaseRegNum = 0;
182     IndexRegNum = 0;
183     DispValue = 0;
184     DispExpr = nullptr;
185     PCRelBaseOut = nullptr;
186     FixedEntryLoadInst = nullptr;
187 
188     // Check for the following long tail call sequence:
189     // 1: auipc xi, %pcrel_hi(sym)
190     // jalr zero, %pcrel_lo(1b)(xi)
191     if (Instruction.getOpcode() == RISCV::JALR && Begin != End) {
192       MCInst &PrevInst = *std::prev(End);
193       if (isRISCVCall(PrevInst, Instruction) &&
194           Instruction.getOperand(0).getReg() == RISCV::X0)
195         return IndirectBranchType::POSSIBLE_TAIL_CALL;
196     }
197 
198     return IndirectBranchType::UNKNOWN;
199   }
200 
201   bool convertJmpToTailCall(MCInst &Inst) override {
202     if (isTailCall(Inst))
203       return false;
204 
205     switch (Inst.getOpcode()) {
206     default:
207       llvm_unreachable("unsupported tail call opcode");
208     case RISCV::JAL:
209     case RISCV::JALR:
210     case RISCV::C_J:
211     case RISCV::C_JR:
212       break;
213     }
214 
215     setTailCall(Inst);
216     return true;
217   }
218 
219   void createReturn(MCInst &Inst) const override {
220     // TODO "c.jr ra" when RVC is enabled
221     Inst.setOpcode(RISCV::JALR);
222     Inst.clear();
223     Inst.addOperand(MCOperand::createReg(RISCV::X0));
224     Inst.addOperand(MCOperand::createReg(RISCV::X1));
225     Inst.addOperand(MCOperand::createImm(0));
226   }
227 
228   void createUncondBranch(MCInst &Inst, const MCSymbol *TBB,
229                           MCContext *Ctx) const override {
230     Inst.setOpcode(RISCV::JAL);
231     Inst.clear();
232     Inst.addOperand(MCOperand::createReg(RISCV::X0));
233     Inst.addOperand(MCOperand::createExpr(
234         MCSymbolRefExpr::create(TBB, MCSymbolRefExpr::VK_None, *Ctx)));
235   }
236 
237   StringRef getTrapFillValue() const override {
238     return StringRef("\0\0\0\0", 4);
239   }
240 
241   void createCall(unsigned Opcode, MCInst &Inst, const MCSymbol *Target,
242                   MCContext *Ctx) {
243     Inst.setOpcode(Opcode);
244     Inst.clear();
245     Inst.addOperand(MCOperand::createExpr(RISCVMCExpr::create(
246         MCSymbolRefExpr::create(Target, MCSymbolRefExpr::VK_None, *Ctx),
247         RISCVMCExpr::VK_RISCV_CALL, *Ctx)));
248   }
249 
250   void createCall(MCInst &Inst, const MCSymbol *Target,
251                   MCContext *Ctx) override {
252     return createCall(RISCV::PseudoCALL, Inst, Target, Ctx);
253   }
254 
255   void createTailCall(MCInst &Inst, const MCSymbol *Target,
256                       MCContext *Ctx) override {
257     return createCall(RISCV::PseudoTAIL, Inst, Target, Ctx);
258   }
259 
260   bool analyzeBranch(InstructionIterator Begin, InstructionIterator End,
261                      const MCSymbol *&TBB, const MCSymbol *&FBB,
262                      MCInst *&CondBranch,
263                      MCInst *&UncondBranch) const override {
264     auto I = End;
265 
266     while (I != Begin) {
267       --I;
268 
269       // Ignore nops and CFIs
270       if (isPseudo(*I) || isNoop(*I))
271         continue;
272 
273       // Stop when we find the first non-terminator
274       if (!isTerminator(*I) || isTailCall(*I) || !isBranch(*I))
275         break;
276 
277       // Handle unconditional branches.
278       if (isUnconditionalBranch(*I)) {
279         // If any code was seen after this unconditional branch, we've seen
280         // unreachable code. Ignore them.
281         CondBranch = nullptr;
282         UncondBranch = &*I;
283         const MCSymbol *Sym = getTargetSymbol(*I);
284         assert(Sym != nullptr &&
285                "Couldn't extract BB symbol from jump operand");
286         TBB = Sym;
287         continue;
288       }
289 
290       // Handle conditional branches and ignore indirect branches
291       if (isIndirectBranch(*I))
292         return false;
293 
294       if (CondBranch == nullptr) {
295         const MCSymbol *TargetBB = getTargetSymbol(*I);
296         if (TargetBB == nullptr) {
297           // Unrecognized branch target
298           return false;
299         }
300         FBB = TBB;
301         TBB = TargetBB;
302         CondBranch = &*I;
303         continue;
304       }
305 
306       llvm_unreachable("multiple conditional branches in one BB");
307     }
308 
309     return true;
310   }
311 
312   bool getSymbolRefOperandNum(const MCInst &Inst, unsigned &OpNum) const {
313     switch (Inst.getOpcode()) {
314     default:
315       return false;
316     case RISCV::C_J:
317       OpNum = 0;
318       return true;
319     case RISCV::AUIPC:
320     case RISCV::JAL:
321     case RISCV::C_BEQZ:
322     case RISCV::C_BNEZ:
323       OpNum = 1;
324       return true;
325     case RISCV::BEQ:
326     case RISCV::BGE:
327     case RISCV::BGEU:
328     case RISCV::BNE:
329     case RISCV::BLT:
330     case RISCV::BLTU:
331       OpNum = 2;
332       return true;
333     }
334   }
335 
336   const MCSymbol *getTargetSymbol(const MCExpr *Expr) const override {
337     auto *RISCVExpr = dyn_cast<RISCVMCExpr>(Expr);
338     if (RISCVExpr && RISCVExpr->getSubExpr())
339       return getTargetSymbol(RISCVExpr->getSubExpr());
340 
341     auto *BinExpr = dyn_cast<MCBinaryExpr>(Expr);
342     if (BinExpr)
343       return getTargetSymbol(BinExpr->getLHS());
344 
345     auto *SymExpr = dyn_cast<MCSymbolRefExpr>(Expr);
346     if (SymExpr && SymExpr->getKind() == MCSymbolRefExpr::VK_None)
347       return &SymExpr->getSymbol();
348 
349     return nullptr;
350   }
351 
352   const MCSymbol *getTargetSymbol(const MCInst &Inst,
353                                   unsigned OpNum = 0) const override {
354     if (!OpNum && !getSymbolRefOperandNum(Inst, OpNum))
355       return nullptr;
356 
357     const MCOperand &Op = Inst.getOperand(OpNum);
358     if (!Op.isExpr())
359       return nullptr;
360 
361     return getTargetSymbol(Op.getExpr());
362   }
363 
364   bool lowerTailCall(MCInst &Inst) override {
365     removeAnnotation(Inst, MCPlus::MCAnnotation::kTailCall);
366     if (getConditionalTailCall(Inst))
367       unsetConditionalTailCall(Inst);
368     return true;
369   }
370 
371   uint64_t analyzePLTEntry(MCInst &Instruction, InstructionIterator Begin,
372                            InstructionIterator End,
373                            uint64_t BeginPC) const override {
374     auto I = Begin;
375 
376     assert(I != End);
377     auto &AUIPC = *I++;
378     assert(AUIPC.getOpcode() == RISCV::AUIPC);
379     assert(AUIPC.getOperand(0).getReg() == RISCV::X28);
380 
381     assert(I != End);
382     auto &LD = *I++;
383     assert(LD.getOpcode() == RISCV::LD);
384     assert(LD.getOperand(0).getReg() == RISCV::X28);
385     assert(LD.getOperand(1).getReg() == RISCV::X28);
386 
387     assert(I != End);
388     auto &JALR = *I++;
389     (void)JALR;
390     assert(JALR.getOpcode() == RISCV::JALR);
391     assert(JALR.getOperand(0).getReg() == RISCV::X6);
392     assert(JALR.getOperand(1).getReg() == RISCV::X28);
393 
394     assert(I != End);
395     auto &NOP = *I++;
396     (void)NOP;
397     assert(isNoop(NOP));
398 
399     assert(I == End);
400 
401     auto AUIPCOffset = AUIPC.getOperand(1).getImm() << 12;
402     auto LDOffset = LD.getOperand(2).getImm();
403     return BeginPC + AUIPCOffset + LDOffset;
404   }
405 
406   bool replaceImmWithSymbolRef(MCInst &Inst, const MCSymbol *Symbol,
407                                int64_t Addend, MCContext *Ctx, int64_t &Value,
408                                uint64_t RelType) const override {
409     unsigned ImmOpNo = -1U;
410 
411     for (unsigned Index = 0; Index < MCPlus::getNumPrimeOperands(Inst);
412          ++Index) {
413       if (Inst.getOperand(Index).isImm()) {
414         ImmOpNo = Index;
415         break;
416       }
417     }
418 
419     if (ImmOpNo == -1U)
420       return false;
421 
422     Value = Inst.getOperand(ImmOpNo).getImm();
423     setOperandToSymbolRef(Inst, ImmOpNo, Symbol, Addend, Ctx, RelType);
424     return true;
425   }
426 
427   const MCExpr *getTargetExprFor(MCInst &Inst, const MCExpr *Expr,
428                                  MCContext &Ctx,
429                                  uint64_t RelType) const override {
430     switch (RelType) {
431     default:
432       return Expr;
433     case ELF::R_RISCV_GOT_HI20:
434     case ELF::R_RISCV_TLS_GOT_HI20:
435       // The GOT is reused so no need to create GOT relocations
436     case ELF::R_RISCV_PCREL_HI20:
437       return RISCVMCExpr::create(Expr, RISCVMCExpr::VK_RISCV_PCREL_HI, Ctx);
438     case ELF::R_RISCV_PCREL_LO12_I:
439     case ELF::R_RISCV_PCREL_LO12_S:
440       return RISCVMCExpr::create(Expr, RISCVMCExpr::VK_RISCV_PCREL_LO, Ctx);
441     case ELF::R_RISCV_HI20:
442       return RISCVMCExpr::create(Expr, RISCVMCExpr::VK_RISCV_HI, Ctx);
443     case ELF::R_RISCV_LO12_I:
444     case ELF::R_RISCV_LO12_S:
445       return RISCVMCExpr::create(Expr, RISCVMCExpr::VK_RISCV_LO, Ctx);
446     case ELF::R_RISCV_CALL:
447       return RISCVMCExpr::create(Expr, RISCVMCExpr::VK_RISCV_CALL, Ctx);
448     case ELF::R_RISCV_CALL_PLT:
449       return RISCVMCExpr::create(Expr, RISCVMCExpr::VK_RISCV_CALL_PLT, Ctx);
450     }
451   }
452 
453   bool evaluateMemOperandTarget(const MCInst &Inst, uint64_t &Target,
454                                 uint64_t Address,
455                                 uint64_t Size) const override {
456     return false;
457   }
458 
459   bool isCallAuipc(const MCInst &Inst) const {
460     if (Inst.getOpcode() != RISCV::AUIPC)
461       return false;
462 
463     const auto &ImmOp = Inst.getOperand(1);
464     if (!ImmOp.isExpr())
465       return false;
466 
467     const auto *ImmExpr = ImmOp.getExpr();
468     if (!isa<RISCVMCExpr>(ImmExpr))
469       return false;
470 
471     switch (cast<RISCVMCExpr>(ImmExpr)->getKind()) {
472     default:
473       return false;
474     case RISCVMCExpr::VK_RISCV_CALL:
475     case RISCVMCExpr::VK_RISCV_CALL_PLT:
476       return true;
477     }
478   }
479 
480   bool isRISCVCall(const MCInst &First, const MCInst &Second) const override {
481     if (!isCallAuipc(First))
482       return false;
483 
484     assert(Second.getOpcode() == RISCV::JALR);
485     return true;
486   }
487 
488   uint16_t getMinFunctionAlignment() const override {
489     if (STI->hasFeature(RISCV::FeatureStdExtC) ||
490         STI->hasFeature(RISCV::FeatureStdExtZca))
491       return 2;
492     return 4;
493   }
494 };
495 
496 } // end anonymous namespace
497 
498 namespace llvm {
499 namespace bolt {
500 
501 MCPlusBuilder *createRISCVMCPlusBuilder(const MCInstrAnalysis *Analysis,
502                                         const MCInstrInfo *Info,
503                                         const MCRegisterInfo *RegInfo,
504                                         const MCSubtargetInfo *STI) {
505   return new RISCVMCPlusBuilder(Analysis, Info, RegInfo, STI);
506 }
507 
508 } // namespace bolt
509 } // namespace llvm
510