//===- RISCVOptWInstrs.cpp - MI W instruction optimizations ---------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===---------------------------------------------------------------------===//
//
// This pass does some optimizations for *W instructions at the MI level.
//
// First it removes unneeded sext.w instructions, either because the
// sign-extended bits aren't consumed or because the input was already sign
// extended by an earlier instruction. Some reaching definitions (e.g. add,
// sub, mul, slli, ld) may be rewritten into their W forms to make the latter
// hold.
//
// Then it removes the -w suffix from each addw, mulw, and slliw instruction
// whenever all users depend only on the lower word of the instruction's
// result. We do this only for addw, mulw, and slliw because their -w forms
// are less compressible.
19*06c3fb27SDimitry Andric // 20*06c3fb27SDimitry Andric //===---------------------------------------------------------------------===// 21*06c3fb27SDimitry Andric 22*06c3fb27SDimitry Andric #include "RISCV.h" 23*06c3fb27SDimitry Andric #include "RISCVMachineFunctionInfo.h" 24*06c3fb27SDimitry Andric #include "RISCVSubtarget.h" 25*06c3fb27SDimitry Andric #include "llvm/ADT/Statistic.h" 26*06c3fb27SDimitry Andric #include "llvm/CodeGen/MachineFunctionPass.h" 27*06c3fb27SDimitry Andric #include "llvm/CodeGen/TargetInstrInfo.h" 28*06c3fb27SDimitry Andric 29*06c3fb27SDimitry Andric using namespace llvm; 30*06c3fb27SDimitry Andric 31*06c3fb27SDimitry Andric #define DEBUG_TYPE "riscv-opt-w-instrs" 32*06c3fb27SDimitry Andric #define RISCV_OPT_W_INSTRS_NAME "RISC-V Optimize W Instructions" 33*06c3fb27SDimitry Andric 34*06c3fb27SDimitry Andric STATISTIC(NumRemovedSExtW, "Number of removed sign-extensions"); 35*06c3fb27SDimitry Andric STATISTIC(NumTransformedToWInstrs, 36*06c3fb27SDimitry Andric "Number of instructions transformed to W-ops"); 37*06c3fb27SDimitry Andric 38*06c3fb27SDimitry Andric static cl::opt<bool> DisableSExtWRemoval("riscv-disable-sextw-removal", 39*06c3fb27SDimitry Andric cl::desc("Disable removal of sext.w"), 40*06c3fb27SDimitry Andric cl::init(false), cl::Hidden); 41*06c3fb27SDimitry Andric static cl::opt<bool> DisableStripWSuffix("riscv-disable-strip-w-suffix", 42*06c3fb27SDimitry Andric cl::desc("Disable strip W suffix"), 43*06c3fb27SDimitry Andric cl::init(false), cl::Hidden); 44*06c3fb27SDimitry Andric 45*06c3fb27SDimitry Andric namespace { 46*06c3fb27SDimitry Andric 47*06c3fb27SDimitry Andric class RISCVOptWInstrs : public MachineFunctionPass { 48*06c3fb27SDimitry Andric public: 49*06c3fb27SDimitry Andric static char ID; 50*06c3fb27SDimitry Andric 51*06c3fb27SDimitry Andric RISCVOptWInstrs() : MachineFunctionPass(ID) { 52*06c3fb27SDimitry Andric initializeRISCVOptWInstrsPass(*PassRegistry::getPassRegistry()); 53*06c3fb27SDimitry Andric } 
54*06c3fb27SDimitry Andric 55*06c3fb27SDimitry Andric bool runOnMachineFunction(MachineFunction &MF) override; 56*06c3fb27SDimitry Andric bool removeSExtWInstrs(MachineFunction &MF, const RISCVInstrInfo &TII, 57*06c3fb27SDimitry Andric const RISCVSubtarget &ST, MachineRegisterInfo &MRI); 58*06c3fb27SDimitry Andric bool stripWSuffixes(MachineFunction &MF, const RISCVInstrInfo &TII, 59*06c3fb27SDimitry Andric const RISCVSubtarget &ST, MachineRegisterInfo &MRI); 60*06c3fb27SDimitry Andric 61*06c3fb27SDimitry Andric void getAnalysisUsage(AnalysisUsage &AU) const override { 62*06c3fb27SDimitry Andric AU.setPreservesCFG(); 63*06c3fb27SDimitry Andric MachineFunctionPass::getAnalysisUsage(AU); 64*06c3fb27SDimitry Andric } 65*06c3fb27SDimitry Andric 66*06c3fb27SDimitry Andric StringRef getPassName() const override { return RISCV_OPT_W_INSTRS_NAME; } 67*06c3fb27SDimitry Andric }; 68*06c3fb27SDimitry Andric 69*06c3fb27SDimitry Andric } // end anonymous namespace 70*06c3fb27SDimitry Andric 71*06c3fb27SDimitry Andric char RISCVOptWInstrs::ID = 0; 72*06c3fb27SDimitry Andric INITIALIZE_PASS(RISCVOptWInstrs, DEBUG_TYPE, RISCV_OPT_W_INSTRS_NAME, false, 73*06c3fb27SDimitry Andric false) 74*06c3fb27SDimitry Andric 75*06c3fb27SDimitry Andric FunctionPass *llvm::createRISCVOptWInstrsPass() { 76*06c3fb27SDimitry Andric return new RISCVOptWInstrs(); 77*06c3fb27SDimitry Andric } 78*06c3fb27SDimitry Andric 79*06c3fb27SDimitry Andric // Checks if all users only demand the lower \p OrigBits of the original 80*06c3fb27SDimitry Andric // instruction's result. 
81*06c3fb27SDimitry Andric // TODO: handle multiple interdependent transformations 82*06c3fb27SDimitry Andric static bool hasAllNBitUsers(const MachineInstr &OrigMI, 83*06c3fb27SDimitry Andric const RISCVSubtarget &ST, 84*06c3fb27SDimitry Andric const MachineRegisterInfo &MRI, unsigned OrigBits) { 85*06c3fb27SDimitry Andric 86*06c3fb27SDimitry Andric SmallSet<std::pair<const MachineInstr *, unsigned>, 4> Visited; 87*06c3fb27SDimitry Andric SmallVector<std::pair<const MachineInstr *, unsigned>, 4> Worklist; 88*06c3fb27SDimitry Andric 89*06c3fb27SDimitry Andric Worklist.push_back(std::make_pair(&OrigMI, OrigBits)); 90*06c3fb27SDimitry Andric 91*06c3fb27SDimitry Andric while (!Worklist.empty()) { 92*06c3fb27SDimitry Andric auto P = Worklist.pop_back_val(); 93*06c3fb27SDimitry Andric const MachineInstr *MI = P.first; 94*06c3fb27SDimitry Andric unsigned Bits = P.second; 95*06c3fb27SDimitry Andric 96*06c3fb27SDimitry Andric if (!Visited.insert(P).second) 97*06c3fb27SDimitry Andric continue; 98*06c3fb27SDimitry Andric 99*06c3fb27SDimitry Andric // Only handle instructions with one def. 
100*06c3fb27SDimitry Andric if (MI->getNumExplicitDefs() != 1) 101*06c3fb27SDimitry Andric return false; 102*06c3fb27SDimitry Andric 103*06c3fb27SDimitry Andric for (auto &UserOp : MRI.use_operands(MI->getOperand(0).getReg())) { 104*06c3fb27SDimitry Andric const MachineInstr *UserMI = UserOp.getParent(); 105*06c3fb27SDimitry Andric unsigned OpIdx = UserOp.getOperandNo(); 106*06c3fb27SDimitry Andric 107*06c3fb27SDimitry Andric switch (UserMI->getOpcode()) { 108*06c3fb27SDimitry Andric default: 109*06c3fb27SDimitry Andric return false; 110*06c3fb27SDimitry Andric 111*06c3fb27SDimitry Andric case RISCV::ADDIW: 112*06c3fb27SDimitry Andric case RISCV::ADDW: 113*06c3fb27SDimitry Andric case RISCV::DIVUW: 114*06c3fb27SDimitry Andric case RISCV::DIVW: 115*06c3fb27SDimitry Andric case RISCV::MULW: 116*06c3fb27SDimitry Andric case RISCV::REMUW: 117*06c3fb27SDimitry Andric case RISCV::REMW: 118*06c3fb27SDimitry Andric case RISCV::SLLIW: 119*06c3fb27SDimitry Andric case RISCV::SLLW: 120*06c3fb27SDimitry Andric case RISCV::SRAIW: 121*06c3fb27SDimitry Andric case RISCV::SRAW: 122*06c3fb27SDimitry Andric case RISCV::SRLIW: 123*06c3fb27SDimitry Andric case RISCV::SRLW: 124*06c3fb27SDimitry Andric case RISCV::SUBW: 125*06c3fb27SDimitry Andric case RISCV::ROLW: 126*06c3fb27SDimitry Andric case RISCV::RORW: 127*06c3fb27SDimitry Andric case RISCV::RORIW: 128*06c3fb27SDimitry Andric case RISCV::CLZW: 129*06c3fb27SDimitry Andric case RISCV::CTZW: 130*06c3fb27SDimitry Andric case RISCV::CPOPW: 131*06c3fb27SDimitry Andric case RISCV::SLLI_UW: 132*06c3fb27SDimitry Andric case RISCV::FMV_W_X: 133*06c3fb27SDimitry Andric case RISCV::FCVT_H_W: 134*06c3fb27SDimitry Andric case RISCV::FCVT_H_WU: 135*06c3fb27SDimitry Andric case RISCV::FCVT_S_W: 136*06c3fb27SDimitry Andric case RISCV::FCVT_S_WU: 137*06c3fb27SDimitry Andric case RISCV::FCVT_D_W: 138*06c3fb27SDimitry Andric case RISCV::FCVT_D_WU: 139*06c3fb27SDimitry Andric if (Bits >= 32) 140*06c3fb27SDimitry Andric break; 141*06c3fb27SDimitry 
Andric return false; 142*06c3fb27SDimitry Andric case RISCV::SEXT_B: 143*06c3fb27SDimitry Andric case RISCV::PACKH: 144*06c3fb27SDimitry Andric if (Bits >= 8) 145*06c3fb27SDimitry Andric break; 146*06c3fb27SDimitry Andric return false; 147*06c3fb27SDimitry Andric case RISCV::SEXT_H: 148*06c3fb27SDimitry Andric case RISCV::FMV_H_X: 149*06c3fb27SDimitry Andric case RISCV::ZEXT_H_RV32: 150*06c3fb27SDimitry Andric case RISCV::ZEXT_H_RV64: 151*06c3fb27SDimitry Andric case RISCV::PACKW: 152*06c3fb27SDimitry Andric if (Bits >= 16) 153*06c3fb27SDimitry Andric break; 154*06c3fb27SDimitry Andric return false; 155*06c3fb27SDimitry Andric 156*06c3fb27SDimitry Andric case RISCV::PACK: 157*06c3fb27SDimitry Andric if (Bits >= (ST.getXLen() / 2)) 158*06c3fb27SDimitry Andric break; 159*06c3fb27SDimitry Andric return false; 160*06c3fb27SDimitry Andric 161*06c3fb27SDimitry Andric case RISCV::SRLI: { 162*06c3fb27SDimitry Andric // If we are shifting right by less than Bits, and users don't demand 163*06c3fb27SDimitry Andric // any bits that were shifted into [Bits-1:0], then we can consider this 164*06c3fb27SDimitry Andric // as an N-Bit user. 165*06c3fb27SDimitry Andric unsigned ShAmt = UserMI->getOperand(2).getImm(); 166*06c3fb27SDimitry Andric if (Bits > ShAmt) { 167*06c3fb27SDimitry Andric Worklist.push_back(std::make_pair(UserMI, Bits - ShAmt)); 168*06c3fb27SDimitry Andric break; 169*06c3fb27SDimitry Andric } 170*06c3fb27SDimitry Andric return false; 171*06c3fb27SDimitry Andric } 172*06c3fb27SDimitry Andric 173*06c3fb27SDimitry Andric // these overwrite higher input bits, otherwise the lower word of output 174*06c3fb27SDimitry Andric // depends only on the lower word of input. So check their uses read W. 
175*06c3fb27SDimitry Andric case RISCV::SLLI: 176*06c3fb27SDimitry Andric if (Bits >= (ST.getXLen() - UserMI->getOperand(2).getImm())) 177*06c3fb27SDimitry Andric break; 178*06c3fb27SDimitry Andric Worklist.push_back(std::make_pair(UserMI, Bits)); 179*06c3fb27SDimitry Andric break; 180*06c3fb27SDimitry Andric case RISCV::ANDI: { 181*06c3fb27SDimitry Andric uint64_t Imm = UserMI->getOperand(2).getImm(); 182*06c3fb27SDimitry Andric if (Bits >= (unsigned)llvm::bit_width(Imm)) 183*06c3fb27SDimitry Andric break; 184*06c3fb27SDimitry Andric Worklist.push_back(std::make_pair(UserMI, Bits)); 185*06c3fb27SDimitry Andric break; 186*06c3fb27SDimitry Andric } 187*06c3fb27SDimitry Andric case RISCV::ORI: { 188*06c3fb27SDimitry Andric uint64_t Imm = UserMI->getOperand(2).getImm(); 189*06c3fb27SDimitry Andric if (Bits >= (unsigned)llvm::bit_width<uint64_t>(~Imm)) 190*06c3fb27SDimitry Andric break; 191*06c3fb27SDimitry Andric Worklist.push_back(std::make_pair(UserMI, Bits)); 192*06c3fb27SDimitry Andric break; 193*06c3fb27SDimitry Andric } 194*06c3fb27SDimitry Andric 195*06c3fb27SDimitry Andric case RISCV::SLL: 196*06c3fb27SDimitry Andric case RISCV::BSET: 197*06c3fb27SDimitry Andric case RISCV::BCLR: 198*06c3fb27SDimitry Andric case RISCV::BINV: 199*06c3fb27SDimitry Andric // Operand 2 is the shift amount which uses log2(xlen) bits. 200*06c3fb27SDimitry Andric if (OpIdx == 2) { 201*06c3fb27SDimitry Andric if (Bits >= Log2_32(ST.getXLen())) 202*06c3fb27SDimitry Andric break; 203*06c3fb27SDimitry Andric return false; 204*06c3fb27SDimitry Andric } 205*06c3fb27SDimitry Andric Worklist.push_back(std::make_pair(UserMI, Bits)); 206*06c3fb27SDimitry Andric break; 207*06c3fb27SDimitry Andric 208*06c3fb27SDimitry Andric case RISCV::SRA: 209*06c3fb27SDimitry Andric case RISCV::SRL: 210*06c3fb27SDimitry Andric case RISCV::ROL: 211*06c3fb27SDimitry Andric case RISCV::ROR: 212*06c3fb27SDimitry Andric // Operand 2 is the shift amount which uses 6 bits. 
213*06c3fb27SDimitry Andric if (OpIdx == 2 && Bits >= Log2_32(ST.getXLen())) 214*06c3fb27SDimitry Andric break; 215*06c3fb27SDimitry Andric return false; 216*06c3fb27SDimitry Andric 217*06c3fb27SDimitry Andric case RISCV::ADD_UW: 218*06c3fb27SDimitry Andric case RISCV::SH1ADD_UW: 219*06c3fb27SDimitry Andric case RISCV::SH2ADD_UW: 220*06c3fb27SDimitry Andric case RISCV::SH3ADD_UW: 221*06c3fb27SDimitry Andric // Operand 1 is implicitly zero extended. 222*06c3fb27SDimitry Andric if (OpIdx == 1 && Bits >= 32) 223*06c3fb27SDimitry Andric break; 224*06c3fb27SDimitry Andric Worklist.push_back(std::make_pair(UserMI, Bits)); 225*06c3fb27SDimitry Andric break; 226*06c3fb27SDimitry Andric 227*06c3fb27SDimitry Andric case RISCV::BEXTI: 228*06c3fb27SDimitry Andric if (UserMI->getOperand(2).getImm() >= Bits) 229*06c3fb27SDimitry Andric return false; 230*06c3fb27SDimitry Andric break; 231*06c3fb27SDimitry Andric 232*06c3fb27SDimitry Andric case RISCV::SB: 233*06c3fb27SDimitry Andric // The first argument is the value to store. 234*06c3fb27SDimitry Andric if (OpIdx == 0 && Bits >= 8) 235*06c3fb27SDimitry Andric break; 236*06c3fb27SDimitry Andric return false; 237*06c3fb27SDimitry Andric case RISCV::SH: 238*06c3fb27SDimitry Andric // The first argument is the value to store. 239*06c3fb27SDimitry Andric if (OpIdx == 0 && Bits >= 16) 240*06c3fb27SDimitry Andric break; 241*06c3fb27SDimitry Andric return false; 242*06c3fb27SDimitry Andric case RISCV::SW: 243*06c3fb27SDimitry Andric // The first argument is the value to store. 244*06c3fb27SDimitry Andric if (OpIdx == 0 && Bits >= 32) 245*06c3fb27SDimitry Andric break; 246*06c3fb27SDimitry Andric return false; 247*06c3fb27SDimitry Andric 248*06c3fb27SDimitry Andric // For these, lower word of output in these operations, depends only on 249*06c3fb27SDimitry Andric // the lower word of input. So, we check all uses only read lower word. 
250*06c3fb27SDimitry Andric case RISCV::COPY: 251*06c3fb27SDimitry Andric case RISCV::PHI: 252*06c3fb27SDimitry Andric 253*06c3fb27SDimitry Andric case RISCV::ADD: 254*06c3fb27SDimitry Andric case RISCV::ADDI: 255*06c3fb27SDimitry Andric case RISCV::AND: 256*06c3fb27SDimitry Andric case RISCV::MUL: 257*06c3fb27SDimitry Andric case RISCV::OR: 258*06c3fb27SDimitry Andric case RISCV::SUB: 259*06c3fb27SDimitry Andric case RISCV::XOR: 260*06c3fb27SDimitry Andric case RISCV::XORI: 261*06c3fb27SDimitry Andric 262*06c3fb27SDimitry Andric case RISCV::ANDN: 263*06c3fb27SDimitry Andric case RISCV::BREV8: 264*06c3fb27SDimitry Andric case RISCV::CLMUL: 265*06c3fb27SDimitry Andric case RISCV::ORC_B: 266*06c3fb27SDimitry Andric case RISCV::ORN: 267*06c3fb27SDimitry Andric case RISCV::SH1ADD: 268*06c3fb27SDimitry Andric case RISCV::SH2ADD: 269*06c3fb27SDimitry Andric case RISCV::SH3ADD: 270*06c3fb27SDimitry Andric case RISCV::XNOR: 271*06c3fb27SDimitry Andric case RISCV::BSETI: 272*06c3fb27SDimitry Andric case RISCV::BCLRI: 273*06c3fb27SDimitry Andric case RISCV::BINVI: 274*06c3fb27SDimitry Andric Worklist.push_back(std::make_pair(UserMI, Bits)); 275*06c3fb27SDimitry Andric break; 276*06c3fb27SDimitry Andric 277*06c3fb27SDimitry Andric case RISCV::PseudoCCMOVGPR: 278*06c3fb27SDimitry Andric // Either operand 4 or operand 5 is returned by this instruction. If 279*06c3fb27SDimitry Andric // only the lower word of the result is used, then only the lower word 280*06c3fb27SDimitry Andric // of operand 4 and 5 is used. 
281*06c3fb27SDimitry Andric if (OpIdx != 4 && OpIdx != 5) 282*06c3fb27SDimitry Andric return false; 283*06c3fb27SDimitry Andric Worklist.push_back(std::make_pair(UserMI, Bits)); 284*06c3fb27SDimitry Andric break; 285*06c3fb27SDimitry Andric 286*06c3fb27SDimitry Andric case RISCV::VT_MASKC: 287*06c3fb27SDimitry Andric case RISCV::VT_MASKCN: 288*06c3fb27SDimitry Andric if (OpIdx != 1) 289*06c3fb27SDimitry Andric return false; 290*06c3fb27SDimitry Andric Worklist.push_back(std::make_pair(UserMI, Bits)); 291*06c3fb27SDimitry Andric break; 292*06c3fb27SDimitry Andric } 293*06c3fb27SDimitry Andric } 294*06c3fb27SDimitry Andric } 295*06c3fb27SDimitry Andric 296*06c3fb27SDimitry Andric return true; 297*06c3fb27SDimitry Andric } 298*06c3fb27SDimitry Andric 299*06c3fb27SDimitry Andric static bool hasAllWUsers(const MachineInstr &OrigMI, const RISCVSubtarget &ST, 300*06c3fb27SDimitry Andric const MachineRegisterInfo &MRI) { 301*06c3fb27SDimitry Andric return hasAllNBitUsers(OrigMI, ST, MRI, 32); 302*06c3fb27SDimitry Andric } 303*06c3fb27SDimitry Andric 304*06c3fb27SDimitry Andric // This function returns true if the machine instruction always outputs a value 305*06c3fb27SDimitry Andric // where bits 63:32 match bit 31. 306*06c3fb27SDimitry Andric static bool isSignExtendingOpW(const MachineInstr &MI, 307*06c3fb27SDimitry Andric const MachineRegisterInfo &MRI) { 308*06c3fb27SDimitry Andric uint64_t TSFlags = MI.getDesc().TSFlags; 309*06c3fb27SDimitry Andric 310*06c3fb27SDimitry Andric // Instructions that can be determined from opcode are marked in tablegen. 311*06c3fb27SDimitry Andric if (TSFlags & RISCVII::IsSignExtendingOpWMask) 312*06c3fb27SDimitry Andric return true; 313*06c3fb27SDimitry Andric 314*06c3fb27SDimitry Andric // Special cases that require checking operands. 
315*06c3fb27SDimitry Andric switch (MI.getOpcode()) { 316*06c3fb27SDimitry Andric // shifting right sufficiently makes the value 32-bit sign-extended 317*06c3fb27SDimitry Andric case RISCV::SRAI: 318*06c3fb27SDimitry Andric return MI.getOperand(2).getImm() >= 32; 319*06c3fb27SDimitry Andric case RISCV::SRLI: 320*06c3fb27SDimitry Andric return MI.getOperand(2).getImm() > 32; 321*06c3fb27SDimitry Andric // The LI pattern ADDI rd, X0, imm is sign extended. 322*06c3fb27SDimitry Andric case RISCV::ADDI: 323*06c3fb27SDimitry Andric return MI.getOperand(1).isReg() && MI.getOperand(1).getReg() == RISCV::X0; 324*06c3fb27SDimitry Andric // An ANDI with an 11 bit immediate will zero bits 63:11. 325*06c3fb27SDimitry Andric case RISCV::ANDI: 326*06c3fb27SDimitry Andric return isUInt<11>(MI.getOperand(2).getImm()); 327*06c3fb27SDimitry Andric // An ORI with an >11 bit immediate (negative 12-bit) will set bits 63:11. 328*06c3fb27SDimitry Andric case RISCV::ORI: 329*06c3fb27SDimitry Andric return !isUInt<11>(MI.getOperand(2).getImm()); 330*06c3fb27SDimitry Andric // Copying from X0 produces zero. 
331*06c3fb27SDimitry Andric case RISCV::COPY: 332*06c3fb27SDimitry Andric return MI.getOperand(1).getReg() == RISCV::X0; 333*06c3fb27SDimitry Andric } 334*06c3fb27SDimitry Andric 335*06c3fb27SDimitry Andric return false; 336*06c3fb27SDimitry Andric } 337*06c3fb27SDimitry Andric 338*06c3fb27SDimitry Andric static bool isSignExtendedW(Register SrcReg, const RISCVSubtarget &ST, 339*06c3fb27SDimitry Andric const MachineRegisterInfo &MRI, 340*06c3fb27SDimitry Andric SmallPtrSetImpl<MachineInstr *> &FixableDef) { 341*06c3fb27SDimitry Andric 342*06c3fb27SDimitry Andric SmallPtrSet<const MachineInstr *, 4> Visited; 343*06c3fb27SDimitry Andric SmallVector<MachineInstr *, 4> Worklist; 344*06c3fb27SDimitry Andric 345*06c3fb27SDimitry Andric auto AddRegDefToWorkList = [&](Register SrcReg) { 346*06c3fb27SDimitry Andric if (!SrcReg.isVirtual()) 347*06c3fb27SDimitry Andric return false; 348*06c3fb27SDimitry Andric MachineInstr *SrcMI = MRI.getVRegDef(SrcReg); 349*06c3fb27SDimitry Andric if (!SrcMI) 350*06c3fb27SDimitry Andric return false; 351*06c3fb27SDimitry Andric // Add SrcMI to the worklist. 352*06c3fb27SDimitry Andric Worklist.push_back(SrcMI); 353*06c3fb27SDimitry Andric return true; 354*06c3fb27SDimitry Andric }; 355*06c3fb27SDimitry Andric 356*06c3fb27SDimitry Andric if (!AddRegDefToWorkList(SrcReg)) 357*06c3fb27SDimitry Andric return false; 358*06c3fb27SDimitry Andric 359*06c3fb27SDimitry Andric while (!Worklist.empty()) { 360*06c3fb27SDimitry Andric MachineInstr *MI = Worklist.pop_back_val(); 361*06c3fb27SDimitry Andric 362*06c3fb27SDimitry Andric // If we already visited this instruction, we don't need to check it again. 363*06c3fb27SDimitry Andric if (!Visited.insert(MI).second) 364*06c3fb27SDimitry Andric continue; 365*06c3fb27SDimitry Andric 366*06c3fb27SDimitry Andric // If this is a sign extending operation we don't need to look any further. 
367*06c3fb27SDimitry Andric if (isSignExtendingOpW(*MI, MRI)) 368*06c3fb27SDimitry Andric continue; 369*06c3fb27SDimitry Andric 370*06c3fb27SDimitry Andric // Is this an instruction that propagates sign extend? 371*06c3fb27SDimitry Andric switch (MI->getOpcode()) { 372*06c3fb27SDimitry Andric default: 373*06c3fb27SDimitry Andric // Unknown opcode, give up. 374*06c3fb27SDimitry Andric return false; 375*06c3fb27SDimitry Andric case RISCV::COPY: { 376*06c3fb27SDimitry Andric const MachineFunction *MF = MI->getMF(); 377*06c3fb27SDimitry Andric const RISCVMachineFunctionInfo *RVFI = 378*06c3fb27SDimitry Andric MF->getInfo<RISCVMachineFunctionInfo>(); 379*06c3fb27SDimitry Andric 380*06c3fb27SDimitry Andric // If this is the entry block and the register is livein, see if we know 381*06c3fb27SDimitry Andric // it is sign extended. 382*06c3fb27SDimitry Andric if (MI->getParent() == &MF->front()) { 383*06c3fb27SDimitry Andric Register VReg = MI->getOperand(0).getReg(); 384*06c3fb27SDimitry Andric if (MF->getRegInfo().isLiveIn(VReg) && RVFI->isSExt32Register(VReg)) 385*06c3fb27SDimitry Andric continue; 386*06c3fb27SDimitry Andric } 387*06c3fb27SDimitry Andric 388*06c3fb27SDimitry Andric Register CopySrcReg = MI->getOperand(1).getReg(); 389*06c3fb27SDimitry Andric if (CopySrcReg == RISCV::X10) { 390*06c3fb27SDimitry Andric // For a method return value, we check the ZExt/SExt flags in attribute. 391*06c3fb27SDimitry Andric // We assume the following code sequence for method call. 392*06c3fb27SDimitry Andric // PseudoCALL @bar, ... 393*06c3fb27SDimitry Andric // ADJCALLSTACKUP 0, 0, implicit-def dead $x2, implicit $x2 394*06c3fb27SDimitry Andric // %0:gpr = COPY $x10 395*06c3fb27SDimitry Andric // 396*06c3fb27SDimitry Andric // We use the PseudoCall to look up the IR function being called to find 397*06c3fb27SDimitry Andric // its return attributes. 
398*06c3fb27SDimitry Andric const MachineBasicBlock *MBB = MI->getParent(); 399*06c3fb27SDimitry Andric auto II = MI->getIterator(); 400*06c3fb27SDimitry Andric if (II == MBB->instr_begin() || 401*06c3fb27SDimitry Andric (--II)->getOpcode() != RISCV::ADJCALLSTACKUP) 402*06c3fb27SDimitry Andric return false; 403*06c3fb27SDimitry Andric 404*06c3fb27SDimitry Andric const MachineInstr &CallMI = *(--II); 405*06c3fb27SDimitry Andric if (!CallMI.isCall() || !CallMI.getOperand(0).isGlobal()) 406*06c3fb27SDimitry Andric return false; 407*06c3fb27SDimitry Andric 408*06c3fb27SDimitry Andric auto *CalleeFn = 409*06c3fb27SDimitry Andric dyn_cast_if_present<Function>(CallMI.getOperand(0).getGlobal()); 410*06c3fb27SDimitry Andric if (!CalleeFn) 411*06c3fb27SDimitry Andric return false; 412*06c3fb27SDimitry Andric 413*06c3fb27SDimitry Andric auto *IntTy = dyn_cast<IntegerType>(CalleeFn->getReturnType()); 414*06c3fb27SDimitry Andric if (!IntTy) 415*06c3fb27SDimitry Andric return false; 416*06c3fb27SDimitry Andric 417*06c3fb27SDimitry Andric const AttributeSet &Attrs = CalleeFn->getAttributes().getRetAttrs(); 418*06c3fb27SDimitry Andric unsigned BitWidth = IntTy->getBitWidth(); 419*06c3fb27SDimitry Andric if ((BitWidth <= 32 && Attrs.hasAttribute(Attribute::SExt)) || 420*06c3fb27SDimitry Andric (BitWidth < 32 && Attrs.hasAttribute(Attribute::ZExt))) 421*06c3fb27SDimitry Andric continue; 422*06c3fb27SDimitry Andric } 423*06c3fb27SDimitry Andric 424*06c3fb27SDimitry Andric if (!AddRegDefToWorkList(CopySrcReg)) 425*06c3fb27SDimitry Andric return false; 426*06c3fb27SDimitry Andric 427*06c3fb27SDimitry Andric break; 428*06c3fb27SDimitry Andric } 429*06c3fb27SDimitry Andric 430*06c3fb27SDimitry Andric // For these, we just need to check if the 1st operand is sign extended. 
431*06c3fb27SDimitry Andric case RISCV::BCLRI: 432*06c3fb27SDimitry Andric case RISCV::BINVI: 433*06c3fb27SDimitry Andric case RISCV::BSETI: 434*06c3fb27SDimitry Andric if (MI->getOperand(2).getImm() >= 31) 435*06c3fb27SDimitry Andric return false; 436*06c3fb27SDimitry Andric [[fallthrough]]; 437*06c3fb27SDimitry Andric case RISCV::REM: 438*06c3fb27SDimitry Andric case RISCV::ANDI: 439*06c3fb27SDimitry Andric case RISCV::ORI: 440*06c3fb27SDimitry Andric case RISCV::XORI: 441*06c3fb27SDimitry Andric // |Remainder| is always <= |Dividend|. If D is 32-bit, then so is R. 442*06c3fb27SDimitry Andric // DIV doesn't work because of the edge case 0xf..f 8000 0000 / (long)-1 443*06c3fb27SDimitry Andric // Logical operations use a sign extended 12-bit immediate. 444*06c3fb27SDimitry Andric if (!AddRegDefToWorkList(MI->getOperand(1).getReg())) 445*06c3fb27SDimitry Andric return false; 446*06c3fb27SDimitry Andric 447*06c3fb27SDimitry Andric break; 448*06c3fb27SDimitry Andric case RISCV::PseudoCCADDW: 449*06c3fb27SDimitry Andric case RISCV::PseudoCCSUBW: 450*06c3fb27SDimitry Andric // Returns operand 4 or an ADDW/SUBW of operands 5 and 6. We only need to 451*06c3fb27SDimitry Andric // check if operand 4 is sign extended. 
452*06c3fb27SDimitry Andric if (!AddRegDefToWorkList(MI->getOperand(4).getReg())) 453*06c3fb27SDimitry Andric return false; 454*06c3fb27SDimitry Andric break; 455*06c3fb27SDimitry Andric case RISCV::REMU: 456*06c3fb27SDimitry Andric case RISCV::AND: 457*06c3fb27SDimitry Andric case RISCV::OR: 458*06c3fb27SDimitry Andric case RISCV::XOR: 459*06c3fb27SDimitry Andric case RISCV::ANDN: 460*06c3fb27SDimitry Andric case RISCV::ORN: 461*06c3fb27SDimitry Andric case RISCV::XNOR: 462*06c3fb27SDimitry Andric case RISCV::MAX: 463*06c3fb27SDimitry Andric case RISCV::MAXU: 464*06c3fb27SDimitry Andric case RISCV::MIN: 465*06c3fb27SDimitry Andric case RISCV::MINU: 466*06c3fb27SDimitry Andric case RISCV::PseudoCCMOVGPR: 467*06c3fb27SDimitry Andric case RISCV::PseudoCCAND: 468*06c3fb27SDimitry Andric case RISCV::PseudoCCOR: 469*06c3fb27SDimitry Andric case RISCV::PseudoCCXOR: 470*06c3fb27SDimitry Andric case RISCV::PHI: { 471*06c3fb27SDimitry Andric // If all incoming values are sign-extended, the output of AND, OR, XOR, 472*06c3fb27SDimitry Andric // MIN, MAX, or PHI is also sign-extended. 473*06c3fb27SDimitry Andric 474*06c3fb27SDimitry Andric // The input registers for PHI are operand 1, 3, ... 475*06c3fb27SDimitry Andric // The input registers for PseudoCCMOVGPR are 4 and 5. 476*06c3fb27SDimitry Andric // The input registers for PseudoCCAND/OR/XOR are 4, 5, and 6. 477*06c3fb27SDimitry Andric // The input registers for others are operand 1 and 2. 
478*06c3fb27SDimitry Andric unsigned B = 1, E = 3, D = 1; 479*06c3fb27SDimitry Andric switch (MI->getOpcode()) { 480*06c3fb27SDimitry Andric case RISCV::PHI: 481*06c3fb27SDimitry Andric E = MI->getNumOperands(); 482*06c3fb27SDimitry Andric D = 2; 483*06c3fb27SDimitry Andric break; 484*06c3fb27SDimitry Andric case RISCV::PseudoCCMOVGPR: 485*06c3fb27SDimitry Andric B = 4; 486*06c3fb27SDimitry Andric E = 6; 487*06c3fb27SDimitry Andric break; 488*06c3fb27SDimitry Andric case RISCV::PseudoCCAND: 489*06c3fb27SDimitry Andric case RISCV::PseudoCCOR: 490*06c3fb27SDimitry Andric case RISCV::PseudoCCXOR: 491*06c3fb27SDimitry Andric B = 4; 492*06c3fb27SDimitry Andric E = 7; 493*06c3fb27SDimitry Andric break; 494*06c3fb27SDimitry Andric } 495*06c3fb27SDimitry Andric 496*06c3fb27SDimitry Andric for (unsigned I = B; I != E; I += D) { 497*06c3fb27SDimitry Andric if (!MI->getOperand(I).isReg()) 498*06c3fb27SDimitry Andric return false; 499*06c3fb27SDimitry Andric 500*06c3fb27SDimitry Andric if (!AddRegDefToWorkList(MI->getOperand(I).getReg())) 501*06c3fb27SDimitry Andric return false; 502*06c3fb27SDimitry Andric } 503*06c3fb27SDimitry Andric 504*06c3fb27SDimitry Andric break; 505*06c3fb27SDimitry Andric } 506*06c3fb27SDimitry Andric 507*06c3fb27SDimitry Andric case RISCV::VT_MASKC: 508*06c3fb27SDimitry Andric case RISCV::VT_MASKCN: 509*06c3fb27SDimitry Andric // Instructions return zero or operand 1. Result is sign extended if 510*06c3fb27SDimitry Andric // operand 1 is sign extended. 
511*06c3fb27SDimitry Andric if (!AddRegDefToWorkList(MI->getOperand(1).getReg())) 512*06c3fb27SDimitry Andric return false; 513*06c3fb27SDimitry Andric break; 514*06c3fb27SDimitry Andric 515*06c3fb27SDimitry Andric // With these opcode, we can "fix" them with the W-version 516*06c3fb27SDimitry Andric // if we know all users of the result only rely on bits 31:0 517*06c3fb27SDimitry Andric case RISCV::SLLI: 518*06c3fb27SDimitry Andric // SLLIW reads the lowest 5 bits, while SLLI reads lowest 6 bits 519*06c3fb27SDimitry Andric if (MI->getOperand(2).getImm() >= 32) 520*06c3fb27SDimitry Andric return false; 521*06c3fb27SDimitry Andric [[fallthrough]]; 522*06c3fb27SDimitry Andric case RISCV::ADDI: 523*06c3fb27SDimitry Andric case RISCV::ADD: 524*06c3fb27SDimitry Andric case RISCV::LD: 525*06c3fb27SDimitry Andric case RISCV::LWU: 526*06c3fb27SDimitry Andric case RISCV::MUL: 527*06c3fb27SDimitry Andric case RISCV::SUB: 528*06c3fb27SDimitry Andric if (hasAllWUsers(*MI, ST, MRI)) { 529*06c3fb27SDimitry Andric FixableDef.insert(MI); 530*06c3fb27SDimitry Andric break; 531*06c3fb27SDimitry Andric } 532*06c3fb27SDimitry Andric return false; 533*06c3fb27SDimitry Andric } 534*06c3fb27SDimitry Andric } 535*06c3fb27SDimitry Andric 536*06c3fb27SDimitry Andric // If we get here, then every node we visited produces a sign extended value 537*06c3fb27SDimitry Andric // or propagated sign extended values. So the result must be sign extended. 
538*06c3fb27SDimitry Andric return true; 539*06c3fb27SDimitry Andric } 540*06c3fb27SDimitry Andric 541*06c3fb27SDimitry Andric static unsigned getWOp(unsigned Opcode) { 542*06c3fb27SDimitry Andric switch (Opcode) { 543*06c3fb27SDimitry Andric case RISCV::ADDI: 544*06c3fb27SDimitry Andric return RISCV::ADDIW; 545*06c3fb27SDimitry Andric case RISCV::ADD: 546*06c3fb27SDimitry Andric return RISCV::ADDW; 547*06c3fb27SDimitry Andric case RISCV::LD: 548*06c3fb27SDimitry Andric case RISCV::LWU: 549*06c3fb27SDimitry Andric return RISCV::LW; 550*06c3fb27SDimitry Andric case RISCV::MUL: 551*06c3fb27SDimitry Andric return RISCV::MULW; 552*06c3fb27SDimitry Andric case RISCV::SLLI: 553*06c3fb27SDimitry Andric return RISCV::SLLIW; 554*06c3fb27SDimitry Andric case RISCV::SUB: 555*06c3fb27SDimitry Andric return RISCV::SUBW; 556*06c3fb27SDimitry Andric default: 557*06c3fb27SDimitry Andric llvm_unreachable("Unexpected opcode for replacement with W variant"); 558*06c3fb27SDimitry Andric } 559*06c3fb27SDimitry Andric } 560*06c3fb27SDimitry Andric 561*06c3fb27SDimitry Andric bool RISCVOptWInstrs::removeSExtWInstrs(MachineFunction &MF, 562*06c3fb27SDimitry Andric const RISCVInstrInfo &TII, 563*06c3fb27SDimitry Andric const RISCVSubtarget &ST, 564*06c3fb27SDimitry Andric MachineRegisterInfo &MRI) { 565*06c3fb27SDimitry Andric if (DisableSExtWRemoval) 566*06c3fb27SDimitry Andric return false; 567*06c3fb27SDimitry Andric 568*06c3fb27SDimitry Andric bool MadeChange = false; 569*06c3fb27SDimitry Andric for (MachineBasicBlock &MBB : MF) { 570*06c3fb27SDimitry Andric for (auto I = MBB.begin(), IE = MBB.end(); I != IE;) { 571*06c3fb27SDimitry Andric MachineInstr *MI = &*I++; 572*06c3fb27SDimitry Andric 573*06c3fb27SDimitry Andric // We're looking for the sext.w pattern ADDIW rd, rs1, 0. 
574*06c3fb27SDimitry Andric if (!RISCV::isSEXT_W(*MI)) 575*06c3fb27SDimitry Andric continue; 576*06c3fb27SDimitry Andric 577*06c3fb27SDimitry Andric Register SrcReg = MI->getOperand(1).getReg(); 578*06c3fb27SDimitry Andric 579*06c3fb27SDimitry Andric SmallPtrSet<MachineInstr *, 4> FixableDefs; 580*06c3fb27SDimitry Andric 581*06c3fb27SDimitry Andric // If all users only use the lower bits, this sext.w is redundant. 582*06c3fb27SDimitry Andric // Or if all definitions reaching MI sign-extend their output, 583*06c3fb27SDimitry Andric // then sext.w is redundant. 584*06c3fb27SDimitry Andric if (!hasAllWUsers(*MI, ST, MRI) && 585*06c3fb27SDimitry Andric !isSignExtendedW(SrcReg, ST, MRI, FixableDefs)) 586*06c3fb27SDimitry Andric continue; 587*06c3fb27SDimitry Andric 588*06c3fb27SDimitry Andric Register DstReg = MI->getOperand(0).getReg(); 589*06c3fb27SDimitry Andric if (!MRI.constrainRegClass(SrcReg, MRI.getRegClass(DstReg))) 590*06c3fb27SDimitry Andric continue; 591*06c3fb27SDimitry Andric 592*06c3fb27SDimitry Andric // Convert Fixable instructions to their W versions. 
593*06c3fb27SDimitry Andric for (MachineInstr *Fixable : FixableDefs) { 594*06c3fb27SDimitry Andric LLVM_DEBUG(dbgs() << "Replacing " << *Fixable); 595*06c3fb27SDimitry Andric Fixable->setDesc(TII.get(getWOp(Fixable->getOpcode()))); 596*06c3fb27SDimitry Andric Fixable->clearFlag(MachineInstr::MIFlag::NoSWrap); 597*06c3fb27SDimitry Andric Fixable->clearFlag(MachineInstr::MIFlag::NoUWrap); 598*06c3fb27SDimitry Andric Fixable->clearFlag(MachineInstr::MIFlag::IsExact); 599*06c3fb27SDimitry Andric LLVM_DEBUG(dbgs() << " with " << *Fixable); 600*06c3fb27SDimitry Andric ++NumTransformedToWInstrs; 601*06c3fb27SDimitry Andric } 602*06c3fb27SDimitry Andric 603*06c3fb27SDimitry Andric LLVM_DEBUG(dbgs() << "Removing redundant sign-extension\n"); 604*06c3fb27SDimitry Andric MRI.replaceRegWith(DstReg, SrcReg); 605*06c3fb27SDimitry Andric MRI.clearKillFlags(SrcReg); 606*06c3fb27SDimitry Andric MI->eraseFromParent(); 607*06c3fb27SDimitry Andric ++NumRemovedSExtW; 608*06c3fb27SDimitry Andric MadeChange = true; 609*06c3fb27SDimitry Andric } 610*06c3fb27SDimitry Andric } 611*06c3fb27SDimitry Andric 612*06c3fb27SDimitry Andric return MadeChange; 613*06c3fb27SDimitry Andric } 614*06c3fb27SDimitry Andric 615*06c3fb27SDimitry Andric bool RISCVOptWInstrs::stripWSuffixes(MachineFunction &MF, 616*06c3fb27SDimitry Andric const RISCVInstrInfo &TII, 617*06c3fb27SDimitry Andric const RISCVSubtarget &ST, 618*06c3fb27SDimitry Andric MachineRegisterInfo &MRI) { 619*06c3fb27SDimitry Andric if (DisableStripWSuffix) 620*06c3fb27SDimitry Andric return false; 621*06c3fb27SDimitry Andric 622*06c3fb27SDimitry Andric bool MadeChange = false; 623*06c3fb27SDimitry Andric for (MachineBasicBlock &MBB : MF) { 624*06c3fb27SDimitry Andric for (auto I = MBB.begin(), IE = MBB.end(); I != IE; ++I) { 625*06c3fb27SDimitry Andric MachineInstr &MI = *I; 626*06c3fb27SDimitry Andric 627*06c3fb27SDimitry Andric unsigned Opc; 628*06c3fb27SDimitry Andric switch (MI.getOpcode()) { 629*06c3fb27SDimitry Andric default: 
630*06c3fb27SDimitry Andric continue; 631*06c3fb27SDimitry Andric case RISCV::ADDW: Opc = RISCV::ADD; break; 632*06c3fb27SDimitry Andric case RISCV::MULW: Opc = RISCV::MUL; break; 633*06c3fb27SDimitry Andric case RISCV::SLLIW: Opc = RISCV::SLLI; break; 634*06c3fb27SDimitry Andric } 635*06c3fb27SDimitry Andric 636*06c3fb27SDimitry Andric if (hasAllWUsers(MI, ST, MRI)) { 637*06c3fb27SDimitry Andric MI.setDesc(TII.get(Opc)); 638*06c3fb27SDimitry Andric MadeChange = true; 639*06c3fb27SDimitry Andric } 640*06c3fb27SDimitry Andric } 641*06c3fb27SDimitry Andric } 642*06c3fb27SDimitry Andric 643*06c3fb27SDimitry Andric return MadeChange; 644*06c3fb27SDimitry Andric } 645*06c3fb27SDimitry Andric 646*06c3fb27SDimitry Andric bool RISCVOptWInstrs::runOnMachineFunction(MachineFunction &MF) { 647*06c3fb27SDimitry Andric if (skipFunction(MF.getFunction())) 648*06c3fb27SDimitry Andric return false; 649*06c3fb27SDimitry Andric 650*06c3fb27SDimitry Andric MachineRegisterInfo &MRI = MF.getRegInfo(); 651*06c3fb27SDimitry Andric const RISCVSubtarget &ST = MF.getSubtarget<RISCVSubtarget>(); 652*06c3fb27SDimitry Andric const RISCVInstrInfo &TII = *ST.getInstrInfo(); 653*06c3fb27SDimitry Andric 654*06c3fb27SDimitry Andric if (!ST.is64Bit()) 655*06c3fb27SDimitry Andric return false; 656*06c3fb27SDimitry Andric 657*06c3fb27SDimitry Andric bool MadeChange = false; 658*06c3fb27SDimitry Andric MadeChange |= removeSExtWInstrs(MF, TII, ST, MRI); 659*06c3fb27SDimitry Andric MadeChange |= stripWSuffixes(MF, TII, ST, MRI); 660*06c3fb27SDimitry Andric 661*06c3fb27SDimitry Andric return MadeChange; 662*06c3fb27SDimitry Andric } 663