xref: /llvm-project/llvm/lib/Target/RISCV/RISCVOptWInstrs.cpp (revision 0a4e1c518bbca5f3bced6ded6dd71d2fe6622ac3)
1 //===- RISCVOptWInstrs.cpp - MI W instruction optimizations ---------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===---------------------------------------------------------------------===//
8 //
9 // This pass does some optimizations for *W instructions at the MI level.
10 //
11 // First it removes unneeded sext.w instructions. Either because the sign
12 // extended bits aren't consumed or because the input was already sign extended
13 // by an earlier instruction.
14 //
15 // Then:
// 1. Unless explicitly disabled or the target prefers instructions with W
//    suffix, it removes the -w suffix from opw instructions whenever all users
//    are dependent only on the lower word of the result of the instruction.
19 //    The cases handled are:
20 //    * addw because c.add has a larger register encoding than c.addw.
21 //    * addiw because it helps reduce test differences between RV32 and RV64
22 //      w/o being a pessimization.
23 //    * mulw because c.mulw doesn't exist but c.mul does (w/ zcb)
24 //    * slliw because c.slliw doesn't exist and c.slli does
25 //
// 2. Or if explicitly enabled or the target prefers instructions with W
//    suffix, it adds the W suffix to the instruction whenever all users are
//    dependent only on the lower word of the result of the instruction.
29 //    The cases handled are:
30 //    * add/addi/sub/mul.
31 //    * slli with imm < 32.
32 //    * ld/lwu.
33 //===---------------------------------------------------------------------===//
34 
35 #include "RISCV.h"
36 #include "RISCVMachineFunctionInfo.h"
37 #include "RISCVSubtarget.h"
38 #include "llvm/ADT/SmallSet.h"
39 #include "llvm/ADT/Statistic.h"
40 #include "llvm/CodeGen/MachineFunctionPass.h"
41 #include "llvm/CodeGen/TargetInstrInfo.h"
42 
43 using namespace llvm;
44 
45 #define DEBUG_TYPE "riscv-opt-w-instrs"
46 #define RISCV_OPT_W_INSTRS_NAME "RISC-V Optimize W Instructions"
47 
48 STATISTIC(NumRemovedSExtW, "Number of removed sign-extensions");
49 STATISTIC(NumTransformedToWInstrs,
50           "Number of instructions transformed to W-ops");
51 
52 static cl::opt<bool> DisableSExtWRemoval("riscv-disable-sextw-removal",
53                                          cl::desc("Disable removal of sext.w"),
54                                          cl::init(false), cl::Hidden);
55 static cl::opt<bool> DisableStripWSuffix("riscv-disable-strip-w-suffix",
56                                          cl::desc("Disable strip W suffix"),
57                                          cl::init(false), cl::Hidden);
58 
59 namespace {
60 
61 class RISCVOptWInstrs : public MachineFunctionPass {
62 public:
63   static char ID;
64 
65   RISCVOptWInstrs() : MachineFunctionPass(ID) {}
66 
67   bool runOnMachineFunction(MachineFunction &MF) override;
68   bool removeSExtWInstrs(MachineFunction &MF, const RISCVInstrInfo &TII,
69                          const RISCVSubtarget &ST, MachineRegisterInfo &MRI);
70   bool stripWSuffixes(MachineFunction &MF, const RISCVInstrInfo &TII,
71                       const RISCVSubtarget &ST, MachineRegisterInfo &MRI);
72   bool appendWSuffixes(MachineFunction &MF, const RISCVInstrInfo &TII,
73                        const RISCVSubtarget &ST, MachineRegisterInfo &MRI);
74 
75   void getAnalysisUsage(AnalysisUsage &AU) const override {
76     AU.setPreservesCFG();
77     MachineFunctionPass::getAnalysisUsage(AU);
78   }
79 
80   StringRef getPassName() const override { return RISCV_OPT_W_INSTRS_NAME; }
81 };
82 
83 } // end anonymous namespace
84 
85 char RISCVOptWInstrs::ID = 0;
86 INITIALIZE_PASS(RISCVOptWInstrs, DEBUG_TYPE, RISCV_OPT_W_INSTRS_NAME, false,
87                 false)
88 
89 FunctionPass *llvm::createRISCVOptWInstrsPass() {
90   return new RISCVOptWInstrs();
91 }
92 
93 static bool vectorPseudoHasAllNBitUsers(const MachineOperand &UserOp,
94                                         unsigned Bits) {
95   const MachineInstr &MI = *UserOp.getParent();
96   unsigned MCOpcode = RISCV::getRVVMCOpcode(MI.getOpcode());
97 
98   if (!MCOpcode)
99     return false;
100 
101   const MCInstrDesc &MCID = MI.getDesc();
102   const uint64_t TSFlags = MCID.TSFlags;
103   if (!RISCVII::hasSEWOp(TSFlags))
104     return false;
105   assert(RISCVII::hasVLOp(TSFlags));
106   const unsigned Log2SEW = MI.getOperand(RISCVII::getSEWOpNum(MCID)).getImm();
107 
108   if (UserOp.getOperandNo() == RISCVII::getVLOpNum(MCID))
109     return false;
110 
111   auto NumDemandedBits =
112       RISCV::getVectorLowDemandedScalarBits(MCOpcode, Log2SEW);
113   return NumDemandedBits && Bits >= *NumDemandedBits;
114 }
115 
116 // Checks if all users only demand the lower \p OrigBits of the original
117 // instruction's result.
118 // TODO: handle multiple interdependent transformations
119 static bool hasAllNBitUsers(const MachineInstr &OrigMI,
120                             const RISCVSubtarget &ST,
121                             const MachineRegisterInfo &MRI, unsigned OrigBits) {
122 
123   SmallSet<std::pair<const MachineInstr *, unsigned>, 4> Visited;
124   SmallVector<std::pair<const MachineInstr *, unsigned>, 4> Worklist;
125 
126   Worklist.push_back(std::make_pair(&OrigMI, OrigBits));
127 
128   while (!Worklist.empty()) {
129     auto P = Worklist.pop_back_val();
130     const MachineInstr *MI = P.first;
131     unsigned Bits = P.second;
132 
133     if (!Visited.insert(P).second)
134       continue;
135 
136     // Only handle instructions with one def.
137     if (MI->getNumExplicitDefs() != 1)
138       return false;
139 
140     Register DestReg = MI->getOperand(0).getReg();
141     if (!DestReg.isVirtual())
142       return false;
143 
144     for (auto &UserOp : MRI.use_nodbg_operands(DestReg)) {
145       const MachineInstr *UserMI = UserOp.getParent();
146       unsigned OpIdx = UserOp.getOperandNo();
147 
148       switch (UserMI->getOpcode()) {
149       default:
150         if (vectorPseudoHasAllNBitUsers(UserOp, Bits))
151           break;
152         return false;
153 
154       case RISCV::ADDIW:
155       case RISCV::ADDW:
156       case RISCV::DIVUW:
157       case RISCV::DIVW:
158       case RISCV::MULW:
159       case RISCV::REMUW:
160       case RISCV::REMW:
161       case RISCV::SLLIW:
162       case RISCV::SLLW:
163       case RISCV::SRAIW:
164       case RISCV::SRAW:
165       case RISCV::SRLIW:
166       case RISCV::SRLW:
167       case RISCV::SUBW:
168       case RISCV::ROLW:
169       case RISCV::RORW:
170       case RISCV::RORIW:
171       case RISCV::CLZW:
172       case RISCV::CTZW:
173       case RISCV::CPOPW:
174       case RISCV::SLLI_UW:
175       case RISCV::FMV_W_X:
176       case RISCV::FCVT_H_W:
177       case RISCV::FCVT_H_W_INX:
178       case RISCV::FCVT_H_WU:
179       case RISCV::FCVT_H_WU_INX:
180       case RISCV::FCVT_S_W:
181       case RISCV::FCVT_S_W_INX:
182       case RISCV::FCVT_S_WU:
183       case RISCV::FCVT_S_WU_INX:
184       case RISCV::FCVT_D_W:
185       case RISCV::FCVT_D_W_INX:
186       case RISCV::FCVT_D_WU:
187       case RISCV::FCVT_D_WU_INX:
188         if (Bits >= 32)
189           break;
190         return false;
191       case RISCV::SEXT_B:
192       case RISCV::PACKH:
193         if (Bits >= 8)
194           break;
195         return false;
196       case RISCV::SEXT_H:
197       case RISCV::FMV_H_X:
198       case RISCV::ZEXT_H_RV32:
199       case RISCV::ZEXT_H_RV64:
200       case RISCV::PACKW:
201         if (Bits >= 16)
202           break;
203         return false;
204 
205       case RISCV::PACK:
206         if (Bits >= (ST.getXLen() / 2))
207           break;
208         return false;
209 
210       case RISCV::SRLI: {
211         // If we are shifting right by less than Bits, and users don't demand
212         // any bits that were shifted into [Bits-1:0], then we can consider this
213         // as an N-Bit user.
214         unsigned ShAmt = UserMI->getOperand(2).getImm();
215         if (Bits > ShAmt) {
216           Worklist.push_back(std::make_pair(UserMI, Bits - ShAmt));
217           break;
218         }
219         return false;
220       }
221 
222       // these overwrite higher input bits, otherwise the lower word of output
223       // depends only on the lower word of input. So check their uses read W.
224       case RISCV::SLLI: {
225         unsigned ShAmt = UserMI->getOperand(2).getImm();
226         if (Bits >= (ST.getXLen() - ShAmt))
227           break;
228         Worklist.push_back(std::make_pair(UserMI, Bits + ShAmt));
229         break;
230       }
231       case RISCV::ANDI: {
232         uint64_t Imm = UserMI->getOperand(2).getImm();
233         if (Bits >= (unsigned)llvm::bit_width(Imm))
234           break;
235         Worklist.push_back(std::make_pair(UserMI, Bits));
236         break;
237       }
238       case RISCV::ORI: {
239         uint64_t Imm = UserMI->getOperand(2).getImm();
240         if (Bits >= (unsigned)llvm::bit_width<uint64_t>(~Imm))
241           break;
242         Worklist.push_back(std::make_pair(UserMI, Bits));
243         break;
244       }
245 
246       case RISCV::SLL:
247       case RISCV::BSET:
248       case RISCV::BCLR:
249       case RISCV::BINV:
250         // Operand 2 is the shift amount which uses log2(xlen) bits.
251         if (OpIdx == 2) {
252           if (Bits >= Log2_32(ST.getXLen()))
253             break;
254           return false;
255         }
256         Worklist.push_back(std::make_pair(UserMI, Bits));
257         break;
258 
259       case RISCV::SRA:
260       case RISCV::SRL:
261       case RISCV::ROL:
262       case RISCV::ROR:
263         // Operand 2 is the shift amount which uses 6 bits.
264         if (OpIdx == 2 && Bits >= Log2_32(ST.getXLen()))
265           break;
266         return false;
267 
268       case RISCV::ADD_UW:
269       case RISCV::SH1ADD_UW:
270       case RISCV::SH2ADD_UW:
271       case RISCV::SH3ADD_UW:
272         // Operand 1 is implicitly zero extended.
273         if (OpIdx == 1 && Bits >= 32)
274           break;
275         Worklist.push_back(std::make_pair(UserMI, Bits));
276         break;
277 
278       case RISCV::BEXTI:
279         if (UserMI->getOperand(2).getImm() >= Bits)
280           return false;
281         break;
282 
283       case RISCV::SB:
284         // The first argument is the value to store.
285         if (OpIdx == 0 && Bits >= 8)
286           break;
287         return false;
288       case RISCV::SH:
289         // The first argument is the value to store.
290         if (OpIdx == 0 && Bits >= 16)
291           break;
292         return false;
293       case RISCV::SW:
294         // The first argument is the value to store.
295         if (OpIdx == 0 && Bits >= 32)
296           break;
297         return false;
298 
299       // For these, lower word of output in these operations, depends only on
300       // the lower word of input. So, we check all uses only read lower word.
301       case RISCV::COPY:
302       case RISCV::PHI:
303 
304       case RISCV::ADD:
305       case RISCV::ADDI:
306       case RISCV::AND:
307       case RISCV::MUL:
308       case RISCV::OR:
309       case RISCV::SUB:
310       case RISCV::XOR:
311       case RISCV::XORI:
312 
313       case RISCV::ANDN:
314       case RISCV::BREV8:
315       case RISCV::CLMUL:
316       case RISCV::ORC_B:
317       case RISCV::ORN:
318       case RISCV::SH1ADD:
319       case RISCV::SH2ADD:
320       case RISCV::SH3ADD:
321       case RISCV::XNOR:
322       case RISCV::BSETI:
323       case RISCV::BCLRI:
324       case RISCV::BINVI:
325         Worklist.push_back(std::make_pair(UserMI, Bits));
326         break;
327 
328       case RISCV::PseudoCCMOVGPR:
329         // Either operand 4 or operand 5 is returned by this instruction. If
330         // only the lower word of the result is used, then only the lower word
331         // of operand 4 and 5 is used.
332         if (OpIdx != 4 && OpIdx != 5)
333           return false;
334         Worklist.push_back(std::make_pair(UserMI, Bits));
335         break;
336 
337       case RISCV::CZERO_EQZ:
338       case RISCV::CZERO_NEZ:
339       case RISCV::VT_MASKC:
340       case RISCV::VT_MASKCN:
341         if (OpIdx != 1)
342           return false;
343         Worklist.push_back(std::make_pair(UserMI, Bits));
344         break;
345       }
346     }
347   }
348 
349   return true;
350 }
351 
352 static bool hasAllWUsers(const MachineInstr &OrigMI, const RISCVSubtarget &ST,
353                          const MachineRegisterInfo &MRI) {
354   return hasAllNBitUsers(OrigMI, ST, MRI, 32);
355 }
356 
357 // This function returns true if the machine instruction always outputs a value
358 // where bits 63:32 match bit 31.
359 static bool isSignExtendingOpW(const MachineInstr &MI, unsigned OpNo) {
360   uint64_t TSFlags = MI.getDesc().TSFlags;
361 
362   // Instructions that can be determined from opcode are marked in tablegen.
363   if (TSFlags & RISCVII::IsSignExtendingOpWMask)
364     return true;
365 
366   // Special cases that require checking operands.
367   switch (MI.getOpcode()) {
368   // shifting right sufficiently makes the value 32-bit sign-extended
369   case RISCV::SRAI:
370     return MI.getOperand(2).getImm() >= 32;
371   case RISCV::SRLI:
372     return MI.getOperand(2).getImm() > 32;
373   // The LI pattern ADDI rd, X0, imm is sign extended.
374   case RISCV::ADDI:
375     return MI.getOperand(1).isReg() && MI.getOperand(1).getReg() == RISCV::X0;
376   // An ANDI with an 11 bit immediate will zero bits 63:11.
377   case RISCV::ANDI:
378     return isUInt<11>(MI.getOperand(2).getImm());
379   // An ORI with an >11 bit immediate (negative 12-bit) will set bits 63:11.
380   case RISCV::ORI:
381     return !isUInt<11>(MI.getOperand(2).getImm());
382   // A bseti with X0 is sign extended if the immediate is less than 31.
383   case RISCV::BSETI:
384     return MI.getOperand(2).getImm() < 31 &&
385            MI.getOperand(1).getReg() == RISCV::X0;
386   // Copying from X0 produces zero.
387   case RISCV::COPY:
388     return MI.getOperand(1).getReg() == RISCV::X0;
389   // Ignore the scratch register destination.
390   case RISCV::PseudoAtomicLoadNand32:
391     return OpNo == 0;
392   case RISCV::PseudoVMV_X_S: {
393     // vmv.x.s has at least 33 sign bits if log2(sew) <= 5.
394     int64_t Log2SEW = MI.getOperand(2).getImm();
395     assert(Log2SEW >= 3 && Log2SEW <= 6 && "Unexpected Log2SEW");
396     return Log2SEW <= 5;
397   }
398   }
399 
400   return false;
401 }
402 
403 static bool isSignExtendedW(Register SrcReg, const RISCVSubtarget &ST,
404                             const MachineRegisterInfo &MRI,
405                             SmallPtrSetImpl<MachineInstr *> &FixableDef) {
406   SmallSet<Register, 4> Visited;
407   SmallVector<Register, 4> Worklist;
408 
409   auto AddRegToWorkList = [&](Register SrcReg) {
410     if (!SrcReg.isVirtual())
411       return false;
412     Worklist.push_back(SrcReg);
413     return true;
414   };
415 
416   if (!AddRegToWorkList(SrcReg))
417     return false;
418 
419   while (!Worklist.empty()) {
420     Register Reg = Worklist.pop_back_val();
421 
422     // If we already visited this register, we don't need to check it again.
423     if (!Visited.insert(Reg).second)
424       continue;
425 
426     MachineInstr *MI = MRI.getVRegDef(Reg);
427     if (!MI)
428       continue;
429 
430     int OpNo = MI->findRegisterDefOperandIdx(Reg, /*TRI=*/nullptr);
431     assert(OpNo != -1 && "Couldn't find register");
432 
433     // If this is a sign extending operation we don't need to look any further.
434     if (isSignExtendingOpW(*MI, OpNo))
435       continue;
436 
437     // Is this an instruction that propagates sign extend?
438     switch (MI->getOpcode()) {
439     default:
440       // Unknown opcode, give up.
441       return false;
442     case RISCV::COPY: {
443       const MachineFunction *MF = MI->getMF();
444       const RISCVMachineFunctionInfo *RVFI =
445           MF->getInfo<RISCVMachineFunctionInfo>();
446 
447       // If this is the entry block and the register is livein, see if we know
448       // it is sign extended.
449       if (MI->getParent() == &MF->front()) {
450         Register VReg = MI->getOperand(0).getReg();
451         if (MF->getRegInfo().isLiveIn(VReg) && RVFI->isSExt32Register(VReg))
452           continue;
453       }
454 
455       Register CopySrcReg = MI->getOperand(1).getReg();
456       if (CopySrcReg == RISCV::X10) {
457         // For a method return value, we check the ZExt/SExt flags in attribute.
458         // We assume the following code sequence for method call.
459         // PseudoCALL @bar, ...
460         // ADJCALLSTACKUP 0, 0, implicit-def dead $x2, implicit $x2
461         // %0:gpr = COPY $x10
462         //
463         // We use the PseudoCall to look up the IR function being called to find
464         // its return attributes.
465         const MachineBasicBlock *MBB = MI->getParent();
466         auto II = MI->getIterator();
467         if (II == MBB->instr_begin() ||
468             (--II)->getOpcode() != RISCV::ADJCALLSTACKUP)
469           return false;
470 
471         const MachineInstr &CallMI = *(--II);
472         if (!CallMI.isCall() || !CallMI.getOperand(0).isGlobal())
473           return false;
474 
475         auto *CalleeFn =
476             dyn_cast_if_present<Function>(CallMI.getOperand(0).getGlobal());
477         if (!CalleeFn)
478           return false;
479 
480         auto *IntTy = dyn_cast<IntegerType>(CalleeFn->getReturnType());
481         if (!IntTy)
482           return false;
483 
484         const AttributeSet &Attrs = CalleeFn->getAttributes().getRetAttrs();
485         unsigned BitWidth = IntTy->getBitWidth();
486         if ((BitWidth <= 32 && Attrs.hasAttribute(Attribute::SExt)) ||
487             (BitWidth < 32 && Attrs.hasAttribute(Attribute::ZExt)))
488           continue;
489       }
490 
491       if (!AddRegToWorkList(CopySrcReg))
492         return false;
493 
494       break;
495     }
496 
497     // For these, we just need to check if the 1st operand is sign extended.
498     case RISCV::BCLRI:
499     case RISCV::BINVI:
500     case RISCV::BSETI:
501       if (MI->getOperand(2).getImm() >= 31)
502         return false;
503       [[fallthrough]];
504     case RISCV::REM:
505     case RISCV::ANDI:
506     case RISCV::ORI:
507     case RISCV::XORI:
508       // |Remainder| is always <= |Dividend|. If D is 32-bit, then so is R.
509       // DIV doesn't work because of the edge case 0xf..f 8000 0000 / (long)-1
510       // Logical operations use a sign extended 12-bit immediate.
511       if (!AddRegToWorkList(MI->getOperand(1).getReg()))
512         return false;
513 
514       break;
515     case RISCV::PseudoCCADDW:
516     case RISCV::PseudoCCADDIW:
517     case RISCV::PseudoCCSUBW:
518     case RISCV::PseudoCCSLLW:
519     case RISCV::PseudoCCSRLW:
520     case RISCV::PseudoCCSRAW:
521     case RISCV::PseudoCCSLLIW:
522     case RISCV::PseudoCCSRLIW:
523     case RISCV::PseudoCCSRAIW:
524       // Returns operand 4 or an ADDW/SUBW/etc. of operands 5 and 6. We only
525       // need to check if operand 4 is sign extended.
526       if (!AddRegToWorkList(MI->getOperand(4).getReg()))
527         return false;
528       break;
529     case RISCV::REMU:
530     case RISCV::AND:
531     case RISCV::OR:
532     case RISCV::XOR:
533     case RISCV::ANDN:
534     case RISCV::ORN:
535     case RISCV::XNOR:
536     case RISCV::MAX:
537     case RISCV::MAXU:
538     case RISCV::MIN:
539     case RISCV::MINU:
540     case RISCV::PseudoCCMOVGPR:
541     case RISCV::PseudoCCAND:
542     case RISCV::PseudoCCOR:
543     case RISCV::PseudoCCXOR:
544     case RISCV::PHI: {
545       // If all incoming values are sign-extended, the output of AND, OR, XOR,
546       // MIN, MAX, or PHI is also sign-extended.
547 
548       // The input registers for PHI are operand 1, 3, ...
549       // The input registers for PseudoCCMOVGPR are 4 and 5.
550       // The input registers for PseudoCCAND/OR/XOR are 4, 5, and 6.
551       // The input registers for others are operand 1 and 2.
552       unsigned B = 1, E = 3, D = 1;
553       switch (MI->getOpcode()) {
554       case RISCV::PHI:
555         E = MI->getNumOperands();
556         D = 2;
557         break;
558       case RISCV::PseudoCCMOVGPR:
559         B = 4;
560         E = 6;
561         break;
562       case RISCV::PseudoCCAND:
563       case RISCV::PseudoCCOR:
564       case RISCV::PseudoCCXOR:
565         B = 4;
566         E = 7;
567         break;
568        }
569 
570       for (unsigned I = B; I != E; I += D) {
571         if (!MI->getOperand(I).isReg())
572           return false;
573 
574         if (!AddRegToWorkList(MI->getOperand(I).getReg()))
575           return false;
576       }
577 
578       break;
579     }
580 
581     case RISCV::CZERO_EQZ:
582     case RISCV::CZERO_NEZ:
583     case RISCV::VT_MASKC:
584     case RISCV::VT_MASKCN:
585       // Instructions return zero or operand 1. Result is sign extended if
586       // operand 1 is sign extended.
587       if (!AddRegToWorkList(MI->getOperand(1).getReg()))
588         return false;
589       break;
590 
591     // With these opcode, we can "fix" them with the W-version
592     // if we know all users of the result only rely on bits 31:0
593     case RISCV::SLLI:
594       // SLLIW reads the lowest 5 bits, while SLLI reads lowest 6 bits
595       if (MI->getOperand(2).getImm() >= 32)
596         return false;
597       [[fallthrough]];
598     case RISCV::ADDI:
599     case RISCV::ADD:
600     case RISCV::LD:
601     case RISCV::LWU:
602     case RISCV::MUL:
603     case RISCV::SUB:
604       if (hasAllWUsers(*MI, ST, MRI)) {
605         FixableDef.insert(MI);
606         break;
607       }
608       return false;
609     }
610   }
611 
612   // If we get here, then every node we visited produces a sign extended value
613   // or propagated sign extended values. So the result must be sign extended.
614   return true;
615 }
616 
617 static unsigned getWOp(unsigned Opcode) {
618   switch (Opcode) {
619   case RISCV::ADDI:
620     return RISCV::ADDIW;
621   case RISCV::ADD:
622     return RISCV::ADDW;
623   case RISCV::LD:
624   case RISCV::LWU:
625     return RISCV::LW;
626   case RISCV::MUL:
627     return RISCV::MULW;
628   case RISCV::SLLI:
629     return RISCV::SLLIW;
630   case RISCV::SUB:
631     return RISCV::SUBW;
632   default:
633     llvm_unreachable("Unexpected opcode for replacement with W variant");
634   }
635 }
636 
637 bool RISCVOptWInstrs::removeSExtWInstrs(MachineFunction &MF,
638                                         const RISCVInstrInfo &TII,
639                                         const RISCVSubtarget &ST,
640                                         MachineRegisterInfo &MRI) {
641   if (DisableSExtWRemoval)
642     return false;
643 
644   bool MadeChange = false;
645   for (MachineBasicBlock &MBB : MF) {
646     for (MachineInstr &MI : llvm::make_early_inc_range(MBB)) {
647       // We're looking for the sext.w pattern ADDIW rd, rs1, 0.
648       if (!RISCV::isSEXT_W(MI))
649         continue;
650 
651       Register SrcReg = MI.getOperand(1).getReg();
652 
653       SmallPtrSet<MachineInstr *, 4> FixableDefs;
654 
655       // If all users only use the lower bits, this sext.w is redundant.
656       // Or if all definitions reaching MI sign-extend their output,
657       // then sext.w is redundant.
658       if (!hasAllWUsers(MI, ST, MRI) &&
659           !isSignExtendedW(SrcReg, ST, MRI, FixableDefs))
660         continue;
661 
662       Register DstReg = MI.getOperand(0).getReg();
663       if (!MRI.constrainRegClass(SrcReg, MRI.getRegClass(DstReg)))
664         continue;
665 
666       // Convert Fixable instructions to their W versions.
667       for (MachineInstr *Fixable : FixableDefs) {
668         LLVM_DEBUG(dbgs() << "Replacing " << *Fixable);
669         Fixable->setDesc(TII.get(getWOp(Fixable->getOpcode())));
670         Fixable->clearFlag(MachineInstr::MIFlag::NoSWrap);
671         Fixable->clearFlag(MachineInstr::MIFlag::NoUWrap);
672         Fixable->clearFlag(MachineInstr::MIFlag::IsExact);
673         LLVM_DEBUG(dbgs() << "     with " << *Fixable);
674         ++NumTransformedToWInstrs;
675       }
676 
677       LLVM_DEBUG(dbgs() << "Removing redundant sign-extension\n");
678       MRI.replaceRegWith(DstReg, SrcReg);
679       MRI.clearKillFlags(SrcReg);
680       MI.eraseFromParent();
681       ++NumRemovedSExtW;
682       MadeChange = true;
683     }
684   }
685 
686   return MadeChange;
687 }
688 
689 bool RISCVOptWInstrs::stripWSuffixes(MachineFunction &MF,
690                                      const RISCVInstrInfo &TII,
691                                      const RISCVSubtarget &ST,
692                                      MachineRegisterInfo &MRI) {
693   bool MadeChange = false;
694   for (MachineBasicBlock &MBB : MF) {
695     for (MachineInstr &MI : MBB) {
696       unsigned Opc;
697       switch (MI.getOpcode()) {
698       default:
699         continue;
700       case RISCV::ADDW:  Opc = RISCV::ADD;  break;
701       case RISCV::ADDIW: Opc = RISCV::ADDI; break;
702       case RISCV::MULW:  Opc = RISCV::MUL;  break;
703       case RISCV::SLLIW: Opc = RISCV::SLLI; break;
704       }
705 
706       if (hasAllWUsers(MI, ST, MRI)) {
707         MI.setDesc(TII.get(Opc));
708         MadeChange = true;
709       }
710     }
711   }
712 
713   return MadeChange;
714 }
715 
716 bool RISCVOptWInstrs::appendWSuffixes(MachineFunction &MF,
717                                       const RISCVInstrInfo &TII,
718                                       const RISCVSubtarget &ST,
719                                       MachineRegisterInfo &MRI) {
720   bool MadeChange = false;
721   for (MachineBasicBlock &MBB : MF) {
722     for (MachineInstr &MI : MBB) {
723       unsigned WOpc;
724       // TODO: Add more?
725       switch (MI.getOpcode()) {
726       default:
727         continue;
728       case RISCV::ADD:
729         WOpc = RISCV::ADDW;
730         break;
731       case RISCV::ADDI:
732         WOpc = RISCV::ADDIW;
733         break;
734       case RISCV::SUB:
735         WOpc = RISCV::SUBW;
736         break;
737       case RISCV::MUL:
738         WOpc = RISCV::MULW;
739         break;
740       case RISCV::SLLI:
741         // SLLIW reads the lowest 5 bits, while SLLI reads lowest 6 bits
742         if (MI.getOperand(2).getImm() >= 32)
743           continue;
744         WOpc = RISCV::SLLIW;
745         break;
746       case RISCV::LD:
747       case RISCV::LWU:
748         WOpc = RISCV::LW;
749         break;
750       }
751 
752       if (hasAllWUsers(MI, ST, MRI)) {
753         LLVM_DEBUG(dbgs() << "Replacing " << MI);
754         MI.setDesc(TII.get(WOpc));
755         MI.clearFlag(MachineInstr::MIFlag::NoSWrap);
756         MI.clearFlag(MachineInstr::MIFlag::NoUWrap);
757         MI.clearFlag(MachineInstr::MIFlag::IsExact);
758         LLVM_DEBUG(dbgs() << "     with " << MI);
759         ++NumTransformedToWInstrs;
760         MadeChange = true;
761       }
762     }
763   }
764 
765   return MadeChange;
766 }
767 
768 bool RISCVOptWInstrs::runOnMachineFunction(MachineFunction &MF) {
769   if (skipFunction(MF.getFunction()))
770     return false;
771 
772   MachineRegisterInfo &MRI = MF.getRegInfo();
773   const RISCVSubtarget &ST = MF.getSubtarget<RISCVSubtarget>();
774   const RISCVInstrInfo &TII = *ST.getInstrInfo();
775 
776   if (!ST.is64Bit())
777     return false;
778 
779   bool MadeChange = false;
780   MadeChange |= removeSExtWInstrs(MF, TII, ST, MRI);
781 
782   if (!(DisableStripWSuffix || ST.preferWInst()))
783     MadeChange |= stripWSuffixes(MF, TII, ST, MRI);
784 
785   if (ST.preferWInst())
786     MadeChange |= appendWSuffixes(MF, TII, ST, MRI);
787 
788   return MadeChange;
789 }
790