xref: /freebsd-src/contrib/llvm-project/llvm/lib/Target/RISCV/RISCVOptWInstrs.cpp (revision 06c3fb2749bda94cb5201f81ffdb8fa6c3161b2e)
1*06c3fb27SDimitry Andric //===- RISCVOptWInstrs.cpp - MI W instruction optimizations ---------------===//
2*06c3fb27SDimitry Andric //
3*06c3fb27SDimitry Andric // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4*06c3fb27SDimitry Andric // See https://llvm.org/LICENSE.txt for license information.
5*06c3fb27SDimitry Andric // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6*06c3fb27SDimitry Andric //
7*06c3fb27SDimitry Andric //===---------------------------------------------------------------------===//
8*06c3fb27SDimitry Andric //
9*06c3fb27SDimitry Andric // This pass does some optimizations for *W instructions at the MI level.
10*06c3fb27SDimitry Andric //
11*06c3fb27SDimitry Andric // First it removes unneeded sext.w instructions. Either because the sign
12*06c3fb27SDimitry Andric // extended bits aren't consumed or because the input was already sign extended
13*06c3fb27SDimitry Andric // by an earlier instruction.
14*06c3fb27SDimitry Andric //
15*06c3fb27SDimitry Andric // Then it removes the -w suffix from each addiw and slliw instructions
16*06c3fb27SDimitry Andric // whenever all users are dependent only on the lower word of the result of the
17*06c3fb27SDimitry Andric // instruction. We do this only for addiw, slliw, and mulw because the -w forms
18*06c3fb27SDimitry Andric // are less compressible.
19*06c3fb27SDimitry Andric //
20*06c3fb27SDimitry Andric //===---------------------------------------------------------------------===//
21*06c3fb27SDimitry Andric 
22*06c3fb27SDimitry Andric #include "RISCV.h"
23*06c3fb27SDimitry Andric #include "RISCVMachineFunctionInfo.h"
24*06c3fb27SDimitry Andric #include "RISCVSubtarget.h"
25*06c3fb27SDimitry Andric #include "llvm/ADT/Statistic.h"
26*06c3fb27SDimitry Andric #include "llvm/CodeGen/MachineFunctionPass.h"
27*06c3fb27SDimitry Andric #include "llvm/CodeGen/TargetInstrInfo.h"
28*06c3fb27SDimitry Andric 
29*06c3fb27SDimitry Andric using namespace llvm;
30*06c3fb27SDimitry Andric 
31*06c3fb27SDimitry Andric #define DEBUG_TYPE "riscv-opt-w-instrs"
32*06c3fb27SDimitry Andric #define RISCV_OPT_W_INSTRS_NAME "RISC-V Optimize W Instructions"
33*06c3fb27SDimitry Andric 
34*06c3fb27SDimitry Andric STATISTIC(NumRemovedSExtW, "Number of removed sign-extensions");
35*06c3fb27SDimitry Andric STATISTIC(NumTransformedToWInstrs,
36*06c3fb27SDimitry Andric           "Number of instructions transformed to W-ops");
37*06c3fb27SDimitry Andric 
38*06c3fb27SDimitry Andric static cl::opt<bool> DisableSExtWRemoval("riscv-disable-sextw-removal",
39*06c3fb27SDimitry Andric                                          cl::desc("Disable removal of sext.w"),
40*06c3fb27SDimitry Andric                                          cl::init(false), cl::Hidden);
41*06c3fb27SDimitry Andric static cl::opt<bool> DisableStripWSuffix("riscv-disable-strip-w-suffix",
42*06c3fb27SDimitry Andric                                          cl::desc("Disable strip W suffix"),
43*06c3fb27SDimitry Andric                                          cl::init(false), cl::Hidden);
44*06c3fb27SDimitry Andric 
45*06c3fb27SDimitry Andric namespace {
46*06c3fb27SDimitry Andric 
47*06c3fb27SDimitry Andric class RISCVOptWInstrs : public MachineFunctionPass {
48*06c3fb27SDimitry Andric public:
49*06c3fb27SDimitry Andric   static char ID;
50*06c3fb27SDimitry Andric 
51*06c3fb27SDimitry Andric   RISCVOptWInstrs() : MachineFunctionPass(ID) {
52*06c3fb27SDimitry Andric     initializeRISCVOptWInstrsPass(*PassRegistry::getPassRegistry());
53*06c3fb27SDimitry Andric   }
54*06c3fb27SDimitry Andric 
55*06c3fb27SDimitry Andric   bool runOnMachineFunction(MachineFunction &MF) override;
56*06c3fb27SDimitry Andric   bool removeSExtWInstrs(MachineFunction &MF, const RISCVInstrInfo &TII,
57*06c3fb27SDimitry Andric                          const RISCVSubtarget &ST, MachineRegisterInfo &MRI);
58*06c3fb27SDimitry Andric   bool stripWSuffixes(MachineFunction &MF, const RISCVInstrInfo &TII,
59*06c3fb27SDimitry Andric                       const RISCVSubtarget &ST, MachineRegisterInfo &MRI);
60*06c3fb27SDimitry Andric 
61*06c3fb27SDimitry Andric   void getAnalysisUsage(AnalysisUsage &AU) const override {
62*06c3fb27SDimitry Andric     AU.setPreservesCFG();
63*06c3fb27SDimitry Andric     MachineFunctionPass::getAnalysisUsage(AU);
64*06c3fb27SDimitry Andric   }
65*06c3fb27SDimitry Andric 
66*06c3fb27SDimitry Andric   StringRef getPassName() const override { return RISCV_OPT_W_INSTRS_NAME; }
67*06c3fb27SDimitry Andric };
68*06c3fb27SDimitry Andric 
69*06c3fb27SDimitry Andric } // end anonymous namespace
70*06c3fb27SDimitry Andric 
71*06c3fb27SDimitry Andric char RISCVOptWInstrs::ID = 0;
72*06c3fb27SDimitry Andric INITIALIZE_PASS(RISCVOptWInstrs, DEBUG_TYPE, RISCV_OPT_W_INSTRS_NAME, false,
73*06c3fb27SDimitry Andric                 false)
74*06c3fb27SDimitry Andric 
75*06c3fb27SDimitry Andric FunctionPass *llvm::createRISCVOptWInstrsPass() {
76*06c3fb27SDimitry Andric   return new RISCVOptWInstrs();
77*06c3fb27SDimitry Andric }
78*06c3fb27SDimitry Andric 
79*06c3fb27SDimitry Andric // Checks if all users only demand the lower \p OrigBits of the original
80*06c3fb27SDimitry Andric // instruction's result.
81*06c3fb27SDimitry Andric // TODO: handle multiple interdependent transformations
82*06c3fb27SDimitry Andric static bool hasAllNBitUsers(const MachineInstr &OrigMI,
83*06c3fb27SDimitry Andric                             const RISCVSubtarget &ST,
84*06c3fb27SDimitry Andric                             const MachineRegisterInfo &MRI, unsigned OrigBits) {
85*06c3fb27SDimitry Andric 
86*06c3fb27SDimitry Andric   SmallSet<std::pair<const MachineInstr *, unsigned>, 4> Visited;
87*06c3fb27SDimitry Andric   SmallVector<std::pair<const MachineInstr *, unsigned>, 4> Worklist;
88*06c3fb27SDimitry Andric 
89*06c3fb27SDimitry Andric   Worklist.push_back(std::make_pair(&OrigMI, OrigBits));
90*06c3fb27SDimitry Andric 
91*06c3fb27SDimitry Andric   while (!Worklist.empty()) {
92*06c3fb27SDimitry Andric     auto P = Worklist.pop_back_val();
93*06c3fb27SDimitry Andric     const MachineInstr *MI = P.first;
94*06c3fb27SDimitry Andric     unsigned Bits = P.second;
95*06c3fb27SDimitry Andric 
96*06c3fb27SDimitry Andric     if (!Visited.insert(P).second)
97*06c3fb27SDimitry Andric       continue;
98*06c3fb27SDimitry Andric 
99*06c3fb27SDimitry Andric     // Only handle instructions with one def.
100*06c3fb27SDimitry Andric     if (MI->getNumExplicitDefs() != 1)
101*06c3fb27SDimitry Andric       return false;
102*06c3fb27SDimitry Andric 
103*06c3fb27SDimitry Andric     for (auto &UserOp : MRI.use_operands(MI->getOperand(0).getReg())) {
104*06c3fb27SDimitry Andric       const MachineInstr *UserMI = UserOp.getParent();
105*06c3fb27SDimitry Andric       unsigned OpIdx = UserOp.getOperandNo();
106*06c3fb27SDimitry Andric 
107*06c3fb27SDimitry Andric       switch (UserMI->getOpcode()) {
108*06c3fb27SDimitry Andric       default:
109*06c3fb27SDimitry Andric         return false;
110*06c3fb27SDimitry Andric 
111*06c3fb27SDimitry Andric       case RISCV::ADDIW:
112*06c3fb27SDimitry Andric       case RISCV::ADDW:
113*06c3fb27SDimitry Andric       case RISCV::DIVUW:
114*06c3fb27SDimitry Andric       case RISCV::DIVW:
115*06c3fb27SDimitry Andric       case RISCV::MULW:
116*06c3fb27SDimitry Andric       case RISCV::REMUW:
117*06c3fb27SDimitry Andric       case RISCV::REMW:
118*06c3fb27SDimitry Andric       case RISCV::SLLIW:
119*06c3fb27SDimitry Andric       case RISCV::SLLW:
120*06c3fb27SDimitry Andric       case RISCV::SRAIW:
121*06c3fb27SDimitry Andric       case RISCV::SRAW:
122*06c3fb27SDimitry Andric       case RISCV::SRLIW:
123*06c3fb27SDimitry Andric       case RISCV::SRLW:
124*06c3fb27SDimitry Andric       case RISCV::SUBW:
125*06c3fb27SDimitry Andric       case RISCV::ROLW:
126*06c3fb27SDimitry Andric       case RISCV::RORW:
127*06c3fb27SDimitry Andric       case RISCV::RORIW:
128*06c3fb27SDimitry Andric       case RISCV::CLZW:
129*06c3fb27SDimitry Andric       case RISCV::CTZW:
130*06c3fb27SDimitry Andric       case RISCV::CPOPW:
131*06c3fb27SDimitry Andric       case RISCV::SLLI_UW:
132*06c3fb27SDimitry Andric       case RISCV::FMV_W_X:
133*06c3fb27SDimitry Andric       case RISCV::FCVT_H_W:
134*06c3fb27SDimitry Andric       case RISCV::FCVT_H_WU:
135*06c3fb27SDimitry Andric       case RISCV::FCVT_S_W:
136*06c3fb27SDimitry Andric       case RISCV::FCVT_S_WU:
137*06c3fb27SDimitry Andric       case RISCV::FCVT_D_W:
138*06c3fb27SDimitry Andric       case RISCV::FCVT_D_WU:
139*06c3fb27SDimitry Andric         if (Bits >= 32)
140*06c3fb27SDimitry Andric           break;
141*06c3fb27SDimitry Andric         return false;
142*06c3fb27SDimitry Andric       case RISCV::SEXT_B:
143*06c3fb27SDimitry Andric       case RISCV::PACKH:
144*06c3fb27SDimitry Andric         if (Bits >= 8)
145*06c3fb27SDimitry Andric           break;
146*06c3fb27SDimitry Andric         return false;
147*06c3fb27SDimitry Andric       case RISCV::SEXT_H:
148*06c3fb27SDimitry Andric       case RISCV::FMV_H_X:
149*06c3fb27SDimitry Andric       case RISCV::ZEXT_H_RV32:
150*06c3fb27SDimitry Andric       case RISCV::ZEXT_H_RV64:
151*06c3fb27SDimitry Andric       case RISCV::PACKW:
152*06c3fb27SDimitry Andric         if (Bits >= 16)
153*06c3fb27SDimitry Andric           break;
154*06c3fb27SDimitry Andric         return false;
155*06c3fb27SDimitry Andric 
156*06c3fb27SDimitry Andric       case RISCV::PACK:
157*06c3fb27SDimitry Andric         if (Bits >= (ST.getXLen() / 2))
158*06c3fb27SDimitry Andric           break;
159*06c3fb27SDimitry Andric         return false;
160*06c3fb27SDimitry Andric 
161*06c3fb27SDimitry Andric       case RISCV::SRLI: {
162*06c3fb27SDimitry Andric         // If we are shifting right by less than Bits, and users don't demand
163*06c3fb27SDimitry Andric         // any bits that were shifted into [Bits-1:0], then we can consider this
164*06c3fb27SDimitry Andric         // as an N-Bit user.
165*06c3fb27SDimitry Andric         unsigned ShAmt = UserMI->getOperand(2).getImm();
166*06c3fb27SDimitry Andric         if (Bits > ShAmt) {
167*06c3fb27SDimitry Andric           Worklist.push_back(std::make_pair(UserMI, Bits - ShAmt));
168*06c3fb27SDimitry Andric           break;
169*06c3fb27SDimitry Andric         }
170*06c3fb27SDimitry Andric         return false;
171*06c3fb27SDimitry Andric       }
172*06c3fb27SDimitry Andric 
173*06c3fb27SDimitry Andric       // these overwrite higher input bits, otherwise the lower word of output
174*06c3fb27SDimitry Andric       // depends only on the lower word of input. So check their uses read W.
175*06c3fb27SDimitry Andric       case RISCV::SLLI:
176*06c3fb27SDimitry Andric         if (Bits >= (ST.getXLen() - UserMI->getOperand(2).getImm()))
177*06c3fb27SDimitry Andric           break;
178*06c3fb27SDimitry Andric         Worklist.push_back(std::make_pair(UserMI, Bits));
179*06c3fb27SDimitry Andric         break;
180*06c3fb27SDimitry Andric       case RISCV::ANDI: {
181*06c3fb27SDimitry Andric         uint64_t Imm = UserMI->getOperand(2).getImm();
182*06c3fb27SDimitry Andric         if (Bits >= (unsigned)llvm::bit_width(Imm))
183*06c3fb27SDimitry Andric           break;
184*06c3fb27SDimitry Andric         Worklist.push_back(std::make_pair(UserMI, Bits));
185*06c3fb27SDimitry Andric         break;
186*06c3fb27SDimitry Andric       }
187*06c3fb27SDimitry Andric       case RISCV::ORI: {
188*06c3fb27SDimitry Andric         uint64_t Imm = UserMI->getOperand(2).getImm();
189*06c3fb27SDimitry Andric         if (Bits >= (unsigned)llvm::bit_width<uint64_t>(~Imm))
190*06c3fb27SDimitry Andric           break;
191*06c3fb27SDimitry Andric         Worklist.push_back(std::make_pair(UserMI, Bits));
192*06c3fb27SDimitry Andric         break;
193*06c3fb27SDimitry Andric       }
194*06c3fb27SDimitry Andric 
195*06c3fb27SDimitry Andric       case RISCV::SLL:
196*06c3fb27SDimitry Andric       case RISCV::BSET:
197*06c3fb27SDimitry Andric       case RISCV::BCLR:
198*06c3fb27SDimitry Andric       case RISCV::BINV:
199*06c3fb27SDimitry Andric         // Operand 2 is the shift amount which uses log2(xlen) bits.
200*06c3fb27SDimitry Andric         if (OpIdx == 2) {
201*06c3fb27SDimitry Andric           if (Bits >= Log2_32(ST.getXLen()))
202*06c3fb27SDimitry Andric             break;
203*06c3fb27SDimitry Andric           return false;
204*06c3fb27SDimitry Andric         }
205*06c3fb27SDimitry Andric         Worklist.push_back(std::make_pair(UserMI, Bits));
206*06c3fb27SDimitry Andric         break;
207*06c3fb27SDimitry Andric 
208*06c3fb27SDimitry Andric       case RISCV::SRA:
209*06c3fb27SDimitry Andric       case RISCV::SRL:
210*06c3fb27SDimitry Andric       case RISCV::ROL:
211*06c3fb27SDimitry Andric       case RISCV::ROR:
212*06c3fb27SDimitry Andric         // Operand 2 is the shift amount which uses 6 bits.
213*06c3fb27SDimitry Andric         if (OpIdx == 2 && Bits >= Log2_32(ST.getXLen()))
214*06c3fb27SDimitry Andric           break;
215*06c3fb27SDimitry Andric         return false;
216*06c3fb27SDimitry Andric 
217*06c3fb27SDimitry Andric       case RISCV::ADD_UW:
218*06c3fb27SDimitry Andric       case RISCV::SH1ADD_UW:
219*06c3fb27SDimitry Andric       case RISCV::SH2ADD_UW:
220*06c3fb27SDimitry Andric       case RISCV::SH3ADD_UW:
221*06c3fb27SDimitry Andric         // Operand 1 is implicitly zero extended.
222*06c3fb27SDimitry Andric         if (OpIdx == 1 && Bits >= 32)
223*06c3fb27SDimitry Andric           break;
224*06c3fb27SDimitry Andric         Worklist.push_back(std::make_pair(UserMI, Bits));
225*06c3fb27SDimitry Andric         break;
226*06c3fb27SDimitry Andric 
227*06c3fb27SDimitry Andric       case RISCV::BEXTI:
228*06c3fb27SDimitry Andric         if (UserMI->getOperand(2).getImm() >= Bits)
229*06c3fb27SDimitry Andric           return false;
230*06c3fb27SDimitry Andric         break;
231*06c3fb27SDimitry Andric 
232*06c3fb27SDimitry Andric       case RISCV::SB:
233*06c3fb27SDimitry Andric         // The first argument is the value to store.
234*06c3fb27SDimitry Andric         if (OpIdx == 0 && Bits >= 8)
235*06c3fb27SDimitry Andric           break;
236*06c3fb27SDimitry Andric         return false;
237*06c3fb27SDimitry Andric       case RISCV::SH:
238*06c3fb27SDimitry Andric         // The first argument is the value to store.
239*06c3fb27SDimitry Andric         if (OpIdx == 0 && Bits >= 16)
240*06c3fb27SDimitry Andric           break;
241*06c3fb27SDimitry Andric         return false;
242*06c3fb27SDimitry Andric       case RISCV::SW:
243*06c3fb27SDimitry Andric         // The first argument is the value to store.
244*06c3fb27SDimitry Andric         if (OpIdx == 0 && Bits >= 32)
245*06c3fb27SDimitry Andric           break;
246*06c3fb27SDimitry Andric         return false;
247*06c3fb27SDimitry Andric 
248*06c3fb27SDimitry Andric       // For these, lower word of output in these operations, depends only on
249*06c3fb27SDimitry Andric       // the lower word of input. So, we check all uses only read lower word.
250*06c3fb27SDimitry Andric       case RISCV::COPY:
251*06c3fb27SDimitry Andric       case RISCV::PHI:
252*06c3fb27SDimitry Andric 
253*06c3fb27SDimitry Andric       case RISCV::ADD:
254*06c3fb27SDimitry Andric       case RISCV::ADDI:
255*06c3fb27SDimitry Andric       case RISCV::AND:
256*06c3fb27SDimitry Andric       case RISCV::MUL:
257*06c3fb27SDimitry Andric       case RISCV::OR:
258*06c3fb27SDimitry Andric       case RISCV::SUB:
259*06c3fb27SDimitry Andric       case RISCV::XOR:
260*06c3fb27SDimitry Andric       case RISCV::XORI:
261*06c3fb27SDimitry Andric 
262*06c3fb27SDimitry Andric       case RISCV::ANDN:
263*06c3fb27SDimitry Andric       case RISCV::BREV8:
264*06c3fb27SDimitry Andric       case RISCV::CLMUL:
265*06c3fb27SDimitry Andric       case RISCV::ORC_B:
266*06c3fb27SDimitry Andric       case RISCV::ORN:
267*06c3fb27SDimitry Andric       case RISCV::SH1ADD:
268*06c3fb27SDimitry Andric       case RISCV::SH2ADD:
269*06c3fb27SDimitry Andric       case RISCV::SH3ADD:
270*06c3fb27SDimitry Andric       case RISCV::XNOR:
271*06c3fb27SDimitry Andric       case RISCV::BSETI:
272*06c3fb27SDimitry Andric       case RISCV::BCLRI:
273*06c3fb27SDimitry Andric       case RISCV::BINVI:
274*06c3fb27SDimitry Andric         Worklist.push_back(std::make_pair(UserMI, Bits));
275*06c3fb27SDimitry Andric         break;
276*06c3fb27SDimitry Andric 
277*06c3fb27SDimitry Andric       case RISCV::PseudoCCMOVGPR:
278*06c3fb27SDimitry Andric         // Either operand 4 or operand 5 is returned by this instruction. If
279*06c3fb27SDimitry Andric         // only the lower word of the result is used, then only the lower word
280*06c3fb27SDimitry Andric         // of operand 4 and 5 is used.
281*06c3fb27SDimitry Andric         if (OpIdx != 4 && OpIdx != 5)
282*06c3fb27SDimitry Andric           return false;
283*06c3fb27SDimitry Andric         Worklist.push_back(std::make_pair(UserMI, Bits));
284*06c3fb27SDimitry Andric         break;
285*06c3fb27SDimitry Andric 
286*06c3fb27SDimitry Andric       case RISCV::VT_MASKC:
287*06c3fb27SDimitry Andric       case RISCV::VT_MASKCN:
288*06c3fb27SDimitry Andric         if (OpIdx != 1)
289*06c3fb27SDimitry Andric           return false;
290*06c3fb27SDimitry Andric         Worklist.push_back(std::make_pair(UserMI, Bits));
291*06c3fb27SDimitry Andric         break;
292*06c3fb27SDimitry Andric       }
293*06c3fb27SDimitry Andric     }
294*06c3fb27SDimitry Andric   }
295*06c3fb27SDimitry Andric 
296*06c3fb27SDimitry Andric   return true;
297*06c3fb27SDimitry Andric }
298*06c3fb27SDimitry Andric 
299*06c3fb27SDimitry Andric static bool hasAllWUsers(const MachineInstr &OrigMI, const RISCVSubtarget &ST,
300*06c3fb27SDimitry Andric                          const MachineRegisterInfo &MRI) {
301*06c3fb27SDimitry Andric   return hasAllNBitUsers(OrigMI, ST, MRI, 32);
302*06c3fb27SDimitry Andric }
303*06c3fb27SDimitry Andric 
304*06c3fb27SDimitry Andric // This function returns true if the machine instruction always outputs a value
305*06c3fb27SDimitry Andric // where bits 63:32 match bit 31.
306*06c3fb27SDimitry Andric static bool isSignExtendingOpW(const MachineInstr &MI,
307*06c3fb27SDimitry Andric                                const MachineRegisterInfo &MRI) {
308*06c3fb27SDimitry Andric   uint64_t TSFlags = MI.getDesc().TSFlags;
309*06c3fb27SDimitry Andric 
310*06c3fb27SDimitry Andric   // Instructions that can be determined from opcode are marked in tablegen.
311*06c3fb27SDimitry Andric   if (TSFlags & RISCVII::IsSignExtendingOpWMask)
312*06c3fb27SDimitry Andric     return true;
313*06c3fb27SDimitry Andric 
314*06c3fb27SDimitry Andric   // Special cases that require checking operands.
315*06c3fb27SDimitry Andric   switch (MI.getOpcode()) {
316*06c3fb27SDimitry Andric   // shifting right sufficiently makes the value 32-bit sign-extended
317*06c3fb27SDimitry Andric   case RISCV::SRAI:
318*06c3fb27SDimitry Andric     return MI.getOperand(2).getImm() >= 32;
319*06c3fb27SDimitry Andric   case RISCV::SRLI:
320*06c3fb27SDimitry Andric     return MI.getOperand(2).getImm() > 32;
321*06c3fb27SDimitry Andric   // The LI pattern ADDI rd, X0, imm is sign extended.
322*06c3fb27SDimitry Andric   case RISCV::ADDI:
323*06c3fb27SDimitry Andric     return MI.getOperand(1).isReg() && MI.getOperand(1).getReg() == RISCV::X0;
324*06c3fb27SDimitry Andric   // An ANDI with an 11 bit immediate will zero bits 63:11.
325*06c3fb27SDimitry Andric   case RISCV::ANDI:
326*06c3fb27SDimitry Andric     return isUInt<11>(MI.getOperand(2).getImm());
327*06c3fb27SDimitry Andric   // An ORI with an >11 bit immediate (negative 12-bit) will set bits 63:11.
328*06c3fb27SDimitry Andric   case RISCV::ORI:
329*06c3fb27SDimitry Andric     return !isUInt<11>(MI.getOperand(2).getImm());
330*06c3fb27SDimitry Andric   // Copying from X0 produces zero.
331*06c3fb27SDimitry Andric   case RISCV::COPY:
332*06c3fb27SDimitry Andric     return MI.getOperand(1).getReg() == RISCV::X0;
333*06c3fb27SDimitry Andric   }
334*06c3fb27SDimitry Andric 
335*06c3fb27SDimitry Andric   return false;
336*06c3fb27SDimitry Andric }
337*06c3fb27SDimitry Andric 
338*06c3fb27SDimitry Andric static bool isSignExtendedW(Register SrcReg, const RISCVSubtarget &ST,
339*06c3fb27SDimitry Andric                             const MachineRegisterInfo &MRI,
340*06c3fb27SDimitry Andric                             SmallPtrSetImpl<MachineInstr *> &FixableDef) {
341*06c3fb27SDimitry Andric 
342*06c3fb27SDimitry Andric   SmallPtrSet<const MachineInstr *, 4> Visited;
343*06c3fb27SDimitry Andric   SmallVector<MachineInstr *, 4> Worklist;
344*06c3fb27SDimitry Andric 
345*06c3fb27SDimitry Andric   auto AddRegDefToWorkList = [&](Register SrcReg) {
346*06c3fb27SDimitry Andric     if (!SrcReg.isVirtual())
347*06c3fb27SDimitry Andric       return false;
348*06c3fb27SDimitry Andric     MachineInstr *SrcMI = MRI.getVRegDef(SrcReg);
349*06c3fb27SDimitry Andric     if (!SrcMI)
350*06c3fb27SDimitry Andric       return false;
351*06c3fb27SDimitry Andric     // Add SrcMI to the worklist.
352*06c3fb27SDimitry Andric     Worklist.push_back(SrcMI);
353*06c3fb27SDimitry Andric     return true;
354*06c3fb27SDimitry Andric   };
355*06c3fb27SDimitry Andric 
356*06c3fb27SDimitry Andric   if (!AddRegDefToWorkList(SrcReg))
357*06c3fb27SDimitry Andric     return false;
358*06c3fb27SDimitry Andric 
359*06c3fb27SDimitry Andric   while (!Worklist.empty()) {
360*06c3fb27SDimitry Andric     MachineInstr *MI = Worklist.pop_back_val();
361*06c3fb27SDimitry Andric 
362*06c3fb27SDimitry Andric     // If we already visited this instruction, we don't need to check it again.
363*06c3fb27SDimitry Andric     if (!Visited.insert(MI).second)
364*06c3fb27SDimitry Andric       continue;
365*06c3fb27SDimitry Andric 
366*06c3fb27SDimitry Andric     // If this is a sign extending operation we don't need to look any further.
367*06c3fb27SDimitry Andric     if (isSignExtendingOpW(*MI, MRI))
368*06c3fb27SDimitry Andric       continue;
369*06c3fb27SDimitry Andric 
370*06c3fb27SDimitry Andric     // Is this an instruction that propagates sign extend?
371*06c3fb27SDimitry Andric     switch (MI->getOpcode()) {
372*06c3fb27SDimitry Andric     default:
373*06c3fb27SDimitry Andric       // Unknown opcode, give up.
374*06c3fb27SDimitry Andric       return false;
375*06c3fb27SDimitry Andric     case RISCV::COPY: {
376*06c3fb27SDimitry Andric       const MachineFunction *MF = MI->getMF();
377*06c3fb27SDimitry Andric       const RISCVMachineFunctionInfo *RVFI =
378*06c3fb27SDimitry Andric           MF->getInfo<RISCVMachineFunctionInfo>();
379*06c3fb27SDimitry Andric 
380*06c3fb27SDimitry Andric       // If this is the entry block and the register is livein, see if we know
381*06c3fb27SDimitry Andric       // it is sign extended.
382*06c3fb27SDimitry Andric       if (MI->getParent() == &MF->front()) {
383*06c3fb27SDimitry Andric         Register VReg = MI->getOperand(0).getReg();
384*06c3fb27SDimitry Andric         if (MF->getRegInfo().isLiveIn(VReg) && RVFI->isSExt32Register(VReg))
385*06c3fb27SDimitry Andric           continue;
386*06c3fb27SDimitry Andric       }
387*06c3fb27SDimitry Andric 
388*06c3fb27SDimitry Andric       Register CopySrcReg = MI->getOperand(1).getReg();
389*06c3fb27SDimitry Andric       if (CopySrcReg == RISCV::X10) {
390*06c3fb27SDimitry Andric         // For a method return value, we check the ZExt/SExt flags in attribute.
391*06c3fb27SDimitry Andric         // We assume the following code sequence for method call.
392*06c3fb27SDimitry Andric         // PseudoCALL @bar, ...
393*06c3fb27SDimitry Andric         // ADJCALLSTACKUP 0, 0, implicit-def dead $x2, implicit $x2
394*06c3fb27SDimitry Andric         // %0:gpr = COPY $x10
395*06c3fb27SDimitry Andric         //
396*06c3fb27SDimitry Andric         // We use the PseudoCall to look up the IR function being called to find
397*06c3fb27SDimitry Andric         // its return attributes.
398*06c3fb27SDimitry Andric         const MachineBasicBlock *MBB = MI->getParent();
399*06c3fb27SDimitry Andric         auto II = MI->getIterator();
400*06c3fb27SDimitry Andric         if (II == MBB->instr_begin() ||
401*06c3fb27SDimitry Andric             (--II)->getOpcode() != RISCV::ADJCALLSTACKUP)
402*06c3fb27SDimitry Andric           return false;
403*06c3fb27SDimitry Andric 
404*06c3fb27SDimitry Andric         const MachineInstr &CallMI = *(--II);
405*06c3fb27SDimitry Andric         if (!CallMI.isCall() || !CallMI.getOperand(0).isGlobal())
406*06c3fb27SDimitry Andric           return false;
407*06c3fb27SDimitry Andric 
408*06c3fb27SDimitry Andric         auto *CalleeFn =
409*06c3fb27SDimitry Andric             dyn_cast_if_present<Function>(CallMI.getOperand(0).getGlobal());
410*06c3fb27SDimitry Andric         if (!CalleeFn)
411*06c3fb27SDimitry Andric           return false;
412*06c3fb27SDimitry Andric 
413*06c3fb27SDimitry Andric         auto *IntTy = dyn_cast<IntegerType>(CalleeFn->getReturnType());
414*06c3fb27SDimitry Andric         if (!IntTy)
415*06c3fb27SDimitry Andric           return false;
416*06c3fb27SDimitry Andric 
417*06c3fb27SDimitry Andric         const AttributeSet &Attrs = CalleeFn->getAttributes().getRetAttrs();
418*06c3fb27SDimitry Andric         unsigned BitWidth = IntTy->getBitWidth();
419*06c3fb27SDimitry Andric         if ((BitWidth <= 32 && Attrs.hasAttribute(Attribute::SExt)) ||
420*06c3fb27SDimitry Andric             (BitWidth < 32 && Attrs.hasAttribute(Attribute::ZExt)))
421*06c3fb27SDimitry Andric           continue;
422*06c3fb27SDimitry Andric       }
423*06c3fb27SDimitry Andric 
424*06c3fb27SDimitry Andric       if (!AddRegDefToWorkList(CopySrcReg))
425*06c3fb27SDimitry Andric         return false;
426*06c3fb27SDimitry Andric 
427*06c3fb27SDimitry Andric       break;
428*06c3fb27SDimitry Andric     }
429*06c3fb27SDimitry Andric 
430*06c3fb27SDimitry Andric     // For these, we just need to check if the 1st operand is sign extended.
431*06c3fb27SDimitry Andric     case RISCV::BCLRI:
432*06c3fb27SDimitry Andric     case RISCV::BINVI:
433*06c3fb27SDimitry Andric     case RISCV::BSETI:
434*06c3fb27SDimitry Andric       if (MI->getOperand(2).getImm() >= 31)
435*06c3fb27SDimitry Andric         return false;
436*06c3fb27SDimitry Andric       [[fallthrough]];
437*06c3fb27SDimitry Andric     case RISCV::REM:
438*06c3fb27SDimitry Andric     case RISCV::ANDI:
439*06c3fb27SDimitry Andric     case RISCV::ORI:
440*06c3fb27SDimitry Andric     case RISCV::XORI:
441*06c3fb27SDimitry Andric       // |Remainder| is always <= |Dividend|. If D is 32-bit, then so is R.
442*06c3fb27SDimitry Andric       // DIV doesn't work because of the edge case 0xf..f 8000 0000 / (long)-1
443*06c3fb27SDimitry Andric       // Logical operations use a sign extended 12-bit immediate.
444*06c3fb27SDimitry Andric       if (!AddRegDefToWorkList(MI->getOperand(1).getReg()))
445*06c3fb27SDimitry Andric         return false;
446*06c3fb27SDimitry Andric 
447*06c3fb27SDimitry Andric       break;
448*06c3fb27SDimitry Andric     case RISCV::PseudoCCADDW:
449*06c3fb27SDimitry Andric     case RISCV::PseudoCCSUBW:
450*06c3fb27SDimitry Andric       // Returns operand 4 or an ADDW/SUBW of operands 5 and 6. We only need to
451*06c3fb27SDimitry Andric       // check if operand 4 is sign extended.
452*06c3fb27SDimitry Andric       if (!AddRegDefToWorkList(MI->getOperand(4).getReg()))
453*06c3fb27SDimitry Andric         return false;
454*06c3fb27SDimitry Andric       break;
455*06c3fb27SDimitry Andric     case RISCV::REMU:
456*06c3fb27SDimitry Andric     case RISCV::AND:
457*06c3fb27SDimitry Andric     case RISCV::OR:
458*06c3fb27SDimitry Andric     case RISCV::XOR:
459*06c3fb27SDimitry Andric     case RISCV::ANDN:
460*06c3fb27SDimitry Andric     case RISCV::ORN:
461*06c3fb27SDimitry Andric     case RISCV::XNOR:
462*06c3fb27SDimitry Andric     case RISCV::MAX:
463*06c3fb27SDimitry Andric     case RISCV::MAXU:
464*06c3fb27SDimitry Andric     case RISCV::MIN:
465*06c3fb27SDimitry Andric     case RISCV::MINU:
466*06c3fb27SDimitry Andric     case RISCV::PseudoCCMOVGPR:
467*06c3fb27SDimitry Andric     case RISCV::PseudoCCAND:
468*06c3fb27SDimitry Andric     case RISCV::PseudoCCOR:
469*06c3fb27SDimitry Andric     case RISCV::PseudoCCXOR:
470*06c3fb27SDimitry Andric     case RISCV::PHI: {
471*06c3fb27SDimitry Andric       // If all incoming values are sign-extended, the output of AND, OR, XOR,
472*06c3fb27SDimitry Andric       // MIN, MAX, or PHI is also sign-extended.
473*06c3fb27SDimitry Andric 
474*06c3fb27SDimitry Andric       // The input registers for PHI are operand 1, 3, ...
475*06c3fb27SDimitry Andric       // The input registers for PseudoCCMOVGPR are 4 and 5.
476*06c3fb27SDimitry Andric       // The input registers for PseudoCCAND/OR/XOR are 4, 5, and 6.
477*06c3fb27SDimitry Andric       // The input registers for others are operand 1 and 2.
478*06c3fb27SDimitry Andric       unsigned B = 1, E = 3, D = 1;
479*06c3fb27SDimitry Andric       switch (MI->getOpcode()) {
480*06c3fb27SDimitry Andric       case RISCV::PHI:
481*06c3fb27SDimitry Andric         E = MI->getNumOperands();
482*06c3fb27SDimitry Andric         D = 2;
483*06c3fb27SDimitry Andric         break;
484*06c3fb27SDimitry Andric       case RISCV::PseudoCCMOVGPR:
485*06c3fb27SDimitry Andric         B = 4;
486*06c3fb27SDimitry Andric         E = 6;
487*06c3fb27SDimitry Andric         break;
488*06c3fb27SDimitry Andric       case RISCV::PseudoCCAND:
489*06c3fb27SDimitry Andric       case RISCV::PseudoCCOR:
490*06c3fb27SDimitry Andric       case RISCV::PseudoCCXOR:
491*06c3fb27SDimitry Andric         B = 4;
492*06c3fb27SDimitry Andric         E = 7;
493*06c3fb27SDimitry Andric         break;
494*06c3fb27SDimitry Andric        }
495*06c3fb27SDimitry Andric 
496*06c3fb27SDimitry Andric       for (unsigned I = B; I != E; I += D) {
497*06c3fb27SDimitry Andric         if (!MI->getOperand(I).isReg())
498*06c3fb27SDimitry Andric           return false;
499*06c3fb27SDimitry Andric 
500*06c3fb27SDimitry Andric         if (!AddRegDefToWorkList(MI->getOperand(I).getReg()))
501*06c3fb27SDimitry Andric           return false;
502*06c3fb27SDimitry Andric       }
503*06c3fb27SDimitry Andric 
504*06c3fb27SDimitry Andric       break;
505*06c3fb27SDimitry Andric     }
506*06c3fb27SDimitry Andric 
507*06c3fb27SDimitry Andric     case RISCV::VT_MASKC:
508*06c3fb27SDimitry Andric     case RISCV::VT_MASKCN:
509*06c3fb27SDimitry Andric       // Instructions return zero or operand 1. Result is sign extended if
510*06c3fb27SDimitry Andric       // operand 1 is sign extended.
511*06c3fb27SDimitry Andric       if (!AddRegDefToWorkList(MI->getOperand(1).getReg()))
512*06c3fb27SDimitry Andric         return false;
513*06c3fb27SDimitry Andric       break;
514*06c3fb27SDimitry Andric 
515*06c3fb27SDimitry Andric     // With these opcode, we can "fix" them with the W-version
516*06c3fb27SDimitry Andric     // if we know all users of the result only rely on bits 31:0
517*06c3fb27SDimitry Andric     case RISCV::SLLI:
518*06c3fb27SDimitry Andric       // SLLIW reads the lowest 5 bits, while SLLI reads lowest 6 bits
519*06c3fb27SDimitry Andric       if (MI->getOperand(2).getImm() >= 32)
520*06c3fb27SDimitry Andric         return false;
521*06c3fb27SDimitry Andric       [[fallthrough]];
522*06c3fb27SDimitry Andric     case RISCV::ADDI:
523*06c3fb27SDimitry Andric     case RISCV::ADD:
524*06c3fb27SDimitry Andric     case RISCV::LD:
525*06c3fb27SDimitry Andric     case RISCV::LWU:
526*06c3fb27SDimitry Andric     case RISCV::MUL:
527*06c3fb27SDimitry Andric     case RISCV::SUB:
528*06c3fb27SDimitry Andric       if (hasAllWUsers(*MI, ST, MRI)) {
529*06c3fb27SDimitry Andric         FixableDef.insert(MI);
530*06c3fb27SDimitry Andric         break;
531*06c3fb27SDimitry Andric       }
532*06c3fb27SDimitry Andric       return false;
533*06c3fb27SDimitry Andric     }
534*06c3fb27SDimitry Andric   }
535*06c3fb27SDimitry Andric 
536*06c3fb27SDimitry Andric   // If we get here, then every node we visited produces a sign extended value
537*06c3fb27SDimitry Andric   // or propagated sign extended values. So the result must be sign extended.
538*06c3fb27SDimitry Andric   return true;
539*06c3fb27SDimitry Andric }
540*06c3fb27SDimitry Andric 
541*06c3fb27SDimitry Andric static unsigned getWOp(unsigned Opcode) {
542*06c3fb27SDimitry Andric   switch (Opcode) {
543*06c3fb27SDimitry Andric   case RISCV::ADDI:
544*06c3fb27SDimitry Andric     return RISCV::ADDIW;
545*06c3fb27SDimitry Andric   case RISCV::ADD:
546*06c3fb27SDimitry Andric     return RISCV::ADDW;
547*06c3fb27SDimitry Andric   case RISCV::LD:
548*06c3fb27SDimitry Andric   case RISCV::LWU:
549*06c3fb27SDimitry Andric     return RISCV::LW;
550*06c3fb27SDimitry Andric   case RISCV::MUL:
551*06c3fb27SDimitry Andric     return RISCV::MULW;
552*06c3fb27SDimitry Andric   case RISCV::SLLI:
553*06c3fb27SDimitry Andric     return RISCV::SLLIW;
554*06c3fb27SDimitry Andric   case RISCV::SUB:
555*06c3fb27SDimitry Andric     return RISCV::SUBW;
556*06c3fb27SDimitry Andric   default:
557*06c3fb27SDimitry Andric     llvm_unreachable("Unexpected opcode for replacement with W variant");
558*06c3fb27SDimitry Andric   }
559*06c3fb27SDimitry Andric }
560*06c3fb27SDimitry Andric 
561*06c3fb27SDimitry Andric bool RISCVOptWInstrs::removeSExtWInstrs(MachineFunction &MF,
562*06c3fb27SDimitry Andric                                         const RISCVInstrInfo &TII,
563*06c3fb27SDimitry Andric                                         const RISCVSubtarget &ST,
564*06c3fb27SDimitry Andric                                         MachineRegisterInfo &MRI) {
565*06c3fb27SDimitry Andric   if (DisableSExtWRemoval)
566*06c3fb27SDimitry Andric     return false;
567*06c3fb27SDimitry Andric 
568*06c3fb27SDimitry Andric   bool MadeChange = false;
569*06c3fb27SDimitry Andric   for (MachineBasicBlock &MBB : MF) {
570*06c3fb27SDimitry Andric     for (auto I = MBB.begin(), IE = MBB.end(); I != IE;) {
571*06c3fb27SDimitry Andric       MachineInstr *MI = &*I++;
572*06c3fb27SDimitry Andric 
573*06c3fb27SDimitry Andric       // We're looking for the sext.w pattern ADDIW rd, rs1, 0.
574*06c3fb27SDimitry Andric       if (!RISCV::isSEXT_W(*MI))
575*06c3fb27SDimitry Andric         continue;
576*06c3fb27SDimitry Andric 
577*06c3fb27SDimitry Andric       Register SrcReg = MI->getOperand(1).getReg();
578*06c3fb27SDimitry Andric 
579*06c3fb27SDimitry Andric       SmallPtrSet<MachineInstr *, 4> FixableDefs;
580*06c3fb27SDimitry Andric 
581*06c3fb27SDimitry Andric       // If all users only use the lower bits, this sext.w is redundant.
582*06c3fb27SDimitry Andric       // Or if all definitions reaching MI sign-extend their output,
583*06c3fb27SDimitry Andric       // then sext.w is redundant.
584*06c3fb27SDimitry Andric       if (!hasAllWUsers(*MI, ST, MRI) &&
585*06c3fb27SDimitry Andric           !isSignExtendedW(SrcReg, ST, MRI, FixableDefs))
586*06c3fb27SDimitry Andric         continue;
587*06c3fb27SDimitry Andric 
588*06c3fb27SDimitry Andric       Register DstReg = MI->getOperand(0).getReg();
589*06c3fb27SDimitry Andric       if (!MRI.constrainRegClass(SrcReg, MRI.getRegClass(DstReg)))
590*06c3fb27SDimitry Andric         continue;
591*06c3fb27SDimitry Andric 
592*06c3fb27SDimitry Andric       // Convert Fixable instructions to their W versions.
593*06c3fb27SDimitry Andric       for (MachineInstr *Fixable : FixableDefs) {
594*06c3fb27SDimitry Andric         LLVM_DEBUG(dbgs() << "Replacing " << *Fixable);
595*06c3fb27SDimitry Andric         Fixable->setDesc(TII.get(getWOp(Fixable->getOpcode())));
596*06c3fb27SDimitry Andric         Fixable->clearFlag(MachineInstr::MIFlag::NoSWrap);
597*06c3fb27SDimitry Andric         Fixable->clearFlag(MachineInstr::MIFlag::NoUWrap);
598*06c3fb27SDimitry Andric         Fixable->clearFlag(MachineInstr::MIFlag::IsExact);
599*06c3fb27SDimitry Andric         LLVM_DEBUG(dbgs() << "     with " << *Fixable);
600*06c3fb27SDimitry Andric         ++NumTransformedToWInstrs;
601*06c3fb27SDimitry Andric       }
602*06c3fb27SDimitry Andric 
603*06c3fb27SDimitry Andric       LLVM_DEBUG(dbgs() << "Removing redundant sign-extension\n");
604*06c3fb27SDimitry Andric       MRI.replaceRegWith(DstReg, SrcReg);
605*06c3fb27SDimitry Andric       MRI.clearKillFlags(SrcReg);
606*06c3fb27SDimitry Andric       MI->eraseFromParent();
607*06c3fb27SDimitry Andric       ++NumRemovedSExtW;
608*06c3fb27SDimitry Andric       MadeChange = true;
609*06c3fb27SDimitry Andric     }
610*06c3fb27SDimitry Andric   }
611*06c3fb27SDimitry Andric 
612*06c3fb27SDimitry Andric   return MadeChange;
613*06c3fb27SDimitry Andric }
614*06c3fb27SDimitry Andric 
615*06c3fb27SDimitry Andric bool RISCVOptWInstrs::stripWSuffixes(MachineFunction &MF,
616*06c3fb27SDimitry Andric                                      const RISCVInstrInfo &TII,
617*06c3fb27SDimitry Andric                                      const RISCVSubtarget &ST,
618*06c3fb27SDimitry Andric                                      MachineRegisterInfo &MRI) {
619*06c3fb27SDimitry Andric   if (DisableStripWSuffix)
620*06c3fb27SDimitry Andric     return false;
621*06c3fb27SDimitry Andric 
622*06c3fb27SDimitry Andric   bool MadeChange = false;
623*06c3fb27SDimitry Andric   for (MachineBasicBlock &MBB : MF) {
624*06c3fb27SDimitry Andric     for (auto I = MBB.begin(), IE = MBB.end(); I != IE; ++I) {
625*06c3fb27SDimitry Andric       MachineInstr &MI = *I;
626*06c3fb27SDimitry Andric 
627*06c3fb27SDimitry Andric       unsigned Opc;
628*06c3fb27SDimitry Andric       switch (MI.getOpcode()) {
629*06c3fb27SDimitry Andric       default:
630*06c3fb27SDimitry Andric         continue;
631*06c3fb27SDimitry Andric       case RISCV::ADDW:  Opc = RISCV::ADD;  break;
632*06c3fb27SDimitry Andric       case RISCV::MULW:  Opc = RISCV::MUL;  break;
633*06c3fb27SDimitry Andric       case RISCV::SLLIW: Opc = RISCV::SLLI; break;
634*06c3fb27SDimitry Andric       }
635*06c3fb27SDimitry Andric 
636*06c3fb27SDimitry Andric       if (hasAllWUsers(MI, ST, MRI)) {
637*06c3fb27SDimitry Andric         MI.setDesc(TII.get(Opc));
638*06c3fb27SDimitry Andric         MadeChange = true;
639*06c3fb27SDimitry Andric       }
640*06c3fb27SDimitry Andric     }
641*06c3fb27SDimitry Andric   }
642*06c3fb27SDimitry Andric 
643*06c3fb27SDimitry Andric   return MadeChange;
644*06c3fb27SDimitry Andric }
645*06c3fb27SDimitry Andric 
646*06c3fb27SDimitry Andric bool RISCVOptWInstrs::runOnMachineFunction(MachineFunction &MF) {
647*06c3fb27SDimitry Andric   if (skipFunction(MF.getFunction()))
648*06c3fb27SDimitry Andric     return false;
649*06c3fb27SDimitry Andric 
650*06c3fb27SDimitry Andric   MachineRegisterInfo &MRI = MF.getRegInfo();
651*06c3fb27SDimitry Andric   const RISCVSubtarget &ST = MF.getSubtarget<RISCVSubtarget>();
652*06c3fb27SDimitry Andric   const RISCVInstrInfo &TII = *ST.getInstrInfo();
653*06c3fb27SDimitry Andric 
654*06c3fb27SDimitry Andric   if (!ST.is64Bit())
655*06c3fb27SDimitry Andric     return false;
656*06c3fb27SDimitry Andric 
657*06c3fb27SDimitry Andric   bool MadeChange = false;
658*06c3fb27SDimitry Andric   MadeChange |= removeSExtWInstrs(MF, TII, ST, MRI);
659*06c3fb27SDimitry Andric   MadeChange |= stripWSuffixes(MF, TII, ST, MRI);
660*06c3fb27SDimitry Andric 
661*06c3fb27SDimitry Andric   return MadeChange;
662*06c3fb27SDimitry Andric }
663