//===- RISCVOptWInstrs.cpp - MI W instruction optimizations ---------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===---------------------------------------------------------------------===//
//
// This pass does some optimizations for *W instructions at the MI level.
//
// First it removes unneeded sext.w instructions, either because the sign
// extended bits aren't consumed or because the input was already sign extended
// by an earlier instruction.
//
// Then:
// 1. Unless explicitly disabled or the target prefers instructions with a W
//    suffix, it removes the W suffix from opw instructions whenever all users
//    depend only on the lower word of the result of the instruction.
//    The cases handled are:
//    * addw because c.add has a larger register encoding than c.addw.
//    * addiw because it helps reduce test differences between RV32 and RV64
//      w/o being a pessimization.
//    * mulw because c.mulw doesn't exist but c.mul does (w/ zcb)
//    * slliw because c.slliw doesn't exist and c.slli does
//
// 2. Or, if explicitly enabled or the target prefers instructions with a W
//    suffix, it adds the W suffix to the instruction whenever all users
//    depend only on the lower word of the result of the instruction.
//    The cases handled are:
//    * add/addi/sub/mul.
//    * slli with imm < 32.
//    * ld/lwu.
//===---------------------------------------------------------------------===//

#include "RISCV.h"
#include "RISCVMachineFunctionInfo.h"
#include "RISCVSubtarget.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/ADT/Statistic.h"
#include "llvm/CodeGen/MachineFunctionPass.h"
#include "llvm/CodeGen/TargetInstrInfo.h"

using namespace llvm;

#define DEBUG_TYPE "riscv-opt-w-instrs"
#define RISCV_OPT_W_INSTRS_NAME "RISC-V Optimize W Instructions"

STATISTIC(NumRemovedSExtW, "Number of removed sign-extensions");
STATISTIC(NumTransformedToWInstrs,
          "Number of instructions transformed to W-ops");

// Hidden debugging switches that turn off individual sub-transformations of
// this pass.
static cl::opt<bool> DisableSExtWRemoval("riscv-disable-sextw-removal",
                                         cl::desc("Disable removal of sext.w"),
                                         cl::init(false), cl::Hidden);
static cl::opt<bool> DisableStripWSuffix("riscv-disable-strip-w-suffix",
                                         cl::desc("Disable strip W suffix"),
                                         cl::init(false), cl::Hidden);

namespace {

// Machine-function pass that removes redundant sext.w instructions and then
// either strips or appends W suffixes on selected instructions (see the file
// header). It only rewrites instructions in place and erases sext.w, so the
// CFG is preserved.
class RISCVOptWInstrs : public MachineFunctionPass {
public:
  static char ID;

  RISCVOptWInstrs() : MachineFunctionPass(ID) {}

  bool runOnMachineFunction(MachineFunction &MF) override;
  bool removeSExtWInstrs(MachineFunction &MF, const RISCVInstrInfo &TII,
                         const RISCVSubtarget &ST, MachineRegisterInfo &MRI);
  bool stripWSuffixes(MachineFunction &MF, const RISCVInstrInfo &TII,
                      const RISCVSubtarget &ST, MachineRegisterInfo &MRI);
  bool appendWSuffixes(MachineFunction &MF, const RISCVInstrInfo &TII,
                       const RISCVSubtarget &ST, MachineRegisterInfo &MRI);

  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.setPreservesCFG();
    MachineFunctionPass::getAnalysisUsage(AU);
  }

  StringRef getPassName() const override { return RISCV_OPT_W_INSTRS_NAME; }
};

} // end anonymous namespace

char RISCVOptWInstrs::ID = 0;
INITIALIZE_PASS(RISCVOptWInstrs, DEBUG_TYPE, RISCV_OPT_W_INSTRS_NAME, false,
                false)

// Factory used by the RISC-V target to add this pass to the pipeline.
FunctionPass *llvm::createRISCVOptWInstrsPass() {
  return new RISCVOptWInstrs();
}

// Returns true if \p UserOp is a scalar operand of an RVV pseudo that reads at
// most the low \p Bits of the register. The demanded width is taken from
// RISCV::getVectorLowDemandedScalarBits for the pseudo's MC opcode and SEW.
// Non-RVV instructions, RVV pseudos without a SEW operand, and the VL operand
// are all rejected.
static bool vectorPseudoHasAllNBitUsers(const MachineOperand &UserOp,
                                        unsigned Bits) {
  const MachineInstr &MI = *UserOp.getParent();
  unsigned MCOpcode = RISCV::getRVVMCOpcode(MI.getOpcode());

  // Not an RVV pseudo.
  if (!MCOpcode)
    return false;

  const MCInstrDesc &MCID = MI.getDesc();
  const uint64_t TSFlags = MCID.TSFlags;
  if (!RISCVII::hasSEWOp(TSFlags))
    return false;
  assert(RISCVII::hasVLOp(TSFlags));
  const unsigned Log2SEW = MI.getOperand(RISCVII::getSEWOpNum(MCID)).getImm();

  // The VL operand is consumed in full.
  if (UserOp.getOperandNo() == RISCVII::getVLOpNum(MCID))
    return false;

  auto NumDemandedBits =
      RISCV::getVectorLowDemandedScalarBits(MCOpcode, Log2SEW);
  return NumDemandedBits && Bits >= *NumDemandedBits;
}

// Checks if all users only demand the lower \p OrigBits of the original
// instruction's result.
//
// This is a worklist search over (instruction, demanded-bits) pairs. For each
// non-debug use of the instruction's single virtual def, one of three things
// happens:
//  * The user reads at most the demanded low bits of the operand: accept.
//  * The low bits of the user's result depend only on the low bits of this
//    operand: queue the user, possibly with an adjusted bit count (SRLI,
//    SLLI), so that its own users are checked.
//  * Anything else: return false.
// TODO: handle multiple interdependent transformations
static bool hasAllNBitUsers(const MachineInstr &OrigMI,
                            const RISCVSubtarget &ST,
                            const MachineRegisterInfo &MRI, unsigned OrigBits) {

  SmallSet<std::pair<const MachineInstr *, unsigned>, 4> Visited;
  SmallVector<std::pair<const MachineInstr *, unsigned>, 4> Worklist;

  Worklist.push_back(std::make_pair(&OrigMI, OrigBits));

  while (!Worklist.empty()) {
    auto P = Worklist.pop_back_val();
    const MachineInstr *MI = P.first;
    unsigned Bits = P.second;

    // The same instruction may be reached along several paths (e.g. through
    // PHIs); process each (instruction, bits) pair only once.
    if (!Visited.insert(P).second)
      continue;

    // Only handle instructions with one def.
    if (MI->getNumExplicitDefs() != 1)
      return false;

    Register DestReg = MI->getOperand(0).getReg();
    if (!DestReg.isVirtual())
      return false;

    for (auto &UserOp : MRI.use_nodbg_operands(DestReg)) {
      const MachineInstr *UserMI = UserOp.getParent();
      unsigned OpIdx = UserOp.getOperandNo();

      switch (UserMI->getOpcode()) {
      default:
        if (vectorPseudoHasAllNBitUsers(UserOp, Bits))
          break;
        return false;

      // These read only the low 32 bits of their register operands.
      case RISCV::ADDIW:
      case RISCV::ADDW:
      case RISCV::DIVUW:
      case RISCV::DIVW:
      case RISCV::MULW:
      case RISCV::REMUW:
      case RISCV::REMW:
      case RISCV::SLLIW:
      case RISCV::SLLW:
      case RISCV::SRAIW:
      case RISCV::SRAW:
      case RISCV::SRLIW:
      case RISCV::SRLW:
      case RISCV::SUBW:
      case RISCV::ROLW:
      case RISCV::RORW:
      case RISCV::RORIW:
      case RISCV::CLZW:
      case RISCV::CTZW:
      case RISCV::CPOPW:
      case RISCV::SLLI_UW:
      case RISCV::FMV_W_X:
      case RISCV::FCVT_H_W:
      case RISCV::FCVT_H_W_INX:
      case RISCV::FCVT_H_WU:
      case RISCV::FCVT_H_WU_INX:
      case RISCV::FCVT_S_W:
      case RISCV::FCVT_S_W_INX:
      case RISCV::FCVT_S_WU:
      case RISCV::FCVT_S_WU_INX:
      case RISCV::FCVT_D_W:
      case RISCV::FCVT_D_W_INX:
      case RISCV::FCVT_D_WU:
      case RISCV::FCVT_D_WU_INX:
        if (Bits >= 32)
          break;
        return false;
      // These read only the low 8 bits.
      case RISCV::SEXT_B:
      case RISCV::PACKH:
        if (Bits >= 8)
          break;
        return false;
      // These read only the low 16 bits.
      case RISCV::SEXT_H:
      case RISCV::FMV_H_X:
      case RISCV::ZEXT_H_RV32:
      case RISCV::ZEXT_H_RV64:
      case RISCV::PACKW:
        if (Bits >= 16)
          break;
        return false;

      // PACK reads the low XLEN/2 bits of each source.
      case RISCV::PACK:
        if (Bits >= (ST.getXLen() / 2))
          break;
        return false;

      case RISCV::SRLI: {
        // If we are shifting right by less than Bits, and users don't demand
        // any bits that were shifted into [Bits-1:0], then we can consider this
        // as an N-Bit user.
        unsigned ShAmt = UserMI->getOperand(2).getImm();
        if (Bits > ShAmt) {
          Worklist.push_back(std::make_pair(UserMI, Bits - ShAmt));
          break;
        }
        return false;
      }

      // these overwrite higher input bits, otherwise the lower word of output
      // depends only on the lower word of input. So check their uses read W.
      case RISCV::SLLI: {
        unsigned ShAmt = UserMI->getOperand(2).getImm();
        // Only the low (XLEN - ShAmt) input bits survive the shift.
        if (Bits >= (ST.getXLen() - ShAmt))
          break;
        Worklist.push_back(std::make_pair(UserMI, Bits + ShAmt));
        break;
      }
      case RISCV::ANDI: {
        // Input bits above the immediate's highest set bit are masked off.
        uint64_t Imm = UserMI->getOperand(2).getImm();
        if (Bits >= (unsigned)llvm::bit_width(Imm))
          break;
        Worklist.push_back(std::make_pair(UserMI, Bits));
        break;
      }
      case RISCV::ORI: {
        // Input bits above the highest clear bit of the immediate are forced
        // to one.
        uint64_t Imm = UserMI->getOperand(2).getImm();
        if (Bits >= (unsigned)llvm::bit_width<uint64_t>(~Imm))
          break;
        Worklist.push_back(std::make_pair(UserMI, Bits));
        break;
      }

      case RISCV::SLL:
      case RISCV::BSET:
      case RISCV::BCLR:
      case RISCV::BINV:
        // Operand 2 is the shift amount which uses log2(xlen) bits.
        if (OpIdx == 2) {
          if (Bits >= Log2_32(ST.getXLen()))
            break;
          return false;
        }
        Worklist.push_back(std::make_pair(UserMI, Bits));
        break;

      case RISCV::SRA:
      case RISCV::SRL:
      case RISCV::ROL:
      case RISCV::ROR:
        // Operand 2 is the shift amount which uses log2(xlen) bits (6 on
        // RV64). The shifted value (operand 1) feeds high bits into the low
        // bits of the result, so it is not accepted.
        if (OpIdx == 2 && Bits >= Log2_32(ST.getXLen()))
          break;
        return false;

      case RISCV::ADD_UW:
      case RISCV::SH1ADD_UW:
      case RISCV::SH2ADD_UW:
      case RISCV::SH3ADD_UW:
        // Operand 1 is implicitly zero extended.
        if (OpIdx == 1 && Bits >= 32)
          break;
        Worklist.push_back(std::make_pair(UserMI, Bits));
        break;

      case RISCV::BEXTI:
        // Reads only the single bit selected by the immediate.
        if (UserMI->getOperand(2).getImm() >= Bits)
          return false;
        break;

      case RISCV::SB:
        // The first argument is the value to store.
        if (OpIdx == 0 && Bits >= 8)
          break;
        return false;
      case RISCV::SH:
        // The first argument is the value to store.
        if (OpIdx == 0 && Bits >= 16)
          break;
        return false;
      case RISCV::SW:
        // The first argument is the value to store.
        if (OpIdx == 0 && Bits >= 32)
          break;
        return false;

      // For these, lower word of output in these operations, depends only on
      // the lower word of input. So, we check all uses only read lower word.
      case RISCV::COPY:
      case RISCV::PHI:

      case RISCV::ADD:
      case RISCV::ADDI:
      case RISCV::AND:
      case RISCV::MUL:
      case RISCV::OR:
      case RISCV::SUB:
      case RISCV::XOR:
      case RISCV::XORI:

      case RISCV::ANDN:
      case RISCV::BREV8:
      case RISCV::CLMUL:
      case RISCV::ORC_B:
      case RISCV::ORN:
      case RISCV::SH1ADD:
      case RISCV::SH2ADD:
      case RISCV::SH3ADD:
      case RISCV::XNOR:
      case RISCV::BSETI:
      case RISCV::BCLRI:
      case RISCV::BINVI:
        Worklist.push_back(std::make_pair(UserMI, Bits));
        break;

      case RISCV::PseudoCCMOVGPR:
        // Either operand 4 or operand 5 is returned by this instruction. If
        // only the lower word of the result is used, then only the lower word
        // of operand 4 and 5 is used.
        if (OpIdx != 4 && OpIdx != 5)
          return false;
        Worklist.push_back(std::make_pair(UserMI, Bits));
        break;

      case RISCV::CZERO_EQZ:
      case RISCV::CZERO_NEZ:
      case RISCV::VT_MASKC:
      case RISCV::VT_MASKCN:
        // Only operand 1 can flow to the result; the condition operand is
        // consumed in full.
        if (OpIdx != 1)
          return false;
        Worklist.push_back(std::make_pair(UserMI, Bits));
        break;
      }
    }
  }

  return true;
}

// Returns true if every (transitive) user of \p OrigMI's result reads only its
// low 32 bits.
static bool hasAllWUsers(const MachineInstr &OrigMI, const RISCVSubtarget &ST,
                         const MachineRegisterInfo &MRI) {
  return hasAllNBitUsers(OrigMI, ST, MRI, 32);
}

// This function returns true if the machine instruction always outputs a value
// where bits 63:32 match bit 31.
359 static bool isSignExtendingOpW(const MachineInstr &MI, unsigned OpNo) { 360 uint64_t TSFlags = MI.getDesc().TSFlags; 361 362 // Instructions that can be determined from opcode are marked in tablegen. 363 if (TSFlags & RISCVII::IsSignExtendingOpWMask) 364 return true; 365 366 // Special cases that require checking operands. 367 switch (MI.getOpcode()) { 368 // shifting right sufficiently makes the value 32-bit sign-extended 369 case RISCV::SRAI: 370 return MI.getOperand(2).getImm() >= 32; 371 case RISCV::SRLI: 372 return MI.getOperand(2).getImm() > 32; 373 // The LI pattern ADDI rd, X0, imm is sign extended. 374 case RISCV::ADDI: 375 return MI.getOperand(1).isReg() && MI.getOperand(1).getReg() == RISCV::X0; 376 // An ANDI with an 11 bit immediate will zero bits 63:11. 377 case RISCV::ANDI: 378 return isUInt<11>(MI.getOperand(2).getImm()); 379 // An ORI with an >11 bit immediate (negative 12-bit) will set bits 63:11. 380 case RISCV::ORI: 381 return !isUInt<11>(MI.getOperand(2).getImm()); 382 // A bseti with X0 is sign extended if the immediate is less than 31. 383 case RISCV::BSETI: 384 return MI.getOperand(2).getImm() < 31 && 385 MI.getOperand(1).getReg() == RISCV::X0; 386 // Copying from X0 produces zero. 387 case RISCV::COPY: 388 return MI.getOperand(1).getReg() == RISCV::X0; 389 // Ignore the scratch register destination. 390 case RISCV::PseudoAtomicLoadNand32: 391 return OpNo == 0; 392 case RISCV::PseudoVMV_X_S: { 393 // vmv.x.s has at least 33 sign bits if log2(sew) <= 5. 
394 int64_t Log2SEW = MI.getOperand(2).getImm(); 395 assert(Log2SEW >= 3 && Log2SEW <= 6 && "Unexpected Log2SEW"); 396 return Log2SEW <= 5; 397 } 398 } 399 400 return false; 401 } 402 403 static bool isSignExtendedW(Register SrcReg, const RISCVSubtarget &ST, 404 const MachineRegisterInfo &MRI, 405 SmallPtrSetImpl<MachineInstr *> &FixableDef) { 406 SmallSet<Register, 4> Visited; 407 SmallVector<Register, 4> Worklist; 408 409 auto AddRegToWorkList = [&](Register SrcReg) { 410 if (!SrcReg.isVirtual()) 411 return false; 412 Worklist.push_back(SrcReg); 413 return true; 414 }; 415 416 if (!AddRegToWorkList(SrcReg)) 417 return false; 418 419 while (!Worklist.empty()) { 420 Register Reg = Worklist.pop_back_val(); 421 422 // If we already visited this register, we don't need to check it again. 423 if (!Visited.insert(Reg).second) 424 continue; 425 426 MachineInstr *MI = MRI.getVRegDef(Reg); 427 if (!MI) 428 continue; 429 430 int OpNo = MI->findRegisterDefOperandIdx(Reg, /*TRI=*/nullptr); 431 assert(OpNo != -1 && "Couldn't find register"); 432 433 // If this is a sign extending operation we don't need to look any further. 434 if (isSignExtendingOpW(*MI, OpNo)) 435 continue; 436 437 // Is this an instruction that propagates sign extend? 438 switch (MI->getOpcode()) { 439 default: 440 // Unknown opcode, give up. 441 return false; 442 case RISCV::COPY: { 443 const MachineFunction *MF = MI->getMF(); 444 const RISCVMachineFunctionInfo *RVFI = 445 MF->getInfo<RISCVMachineFunctionInfo>(); 446 447 // If this is the entry block and the register is livein, see if we know 448 // it is sign extended. 449 if (MI->getParent() == &MF->front()) { 450 Register VReg = MI->getOperand(0).getReg(); 451 if (MF->getRegInfo().isLiveIn(VReg) && RVFI->isSExt32Register(VReg)) 452 continue; 453 } 454 455 Register CopySrcReg = MI->getOperand(1).getReg(); 456 if (CopySrcReg == RISCV::X10) { 457 // For a method return value, we check the ZExt/SExt flags in attribute. 
458 // We assume the following code sequence for method call. 459 // PseudoCALL @bar, ... 460 // ADJCALLSTACKUP 0, 0, implicit-def dead $x2, implicit $x2 461 // %0:gpr = COPY $x10 462 // 463 // We use the PseudoCall to look up the IR function being called to find 464 // its return attributes. 465 const MachineBasicBlock *MBB = MI->getParent(); 466 auto II = MI->getIterator(); 467 if (II == MBB->instr_begin() || 468 (--II)->getOpcode() != RISCV::ADJCALLSTACKUP) 469 return false; 470 471 const MachineInstr &CallMI = *(--II); 472 if (!CallMI.isCall() || !CallMI.getOperand(0).isGlobal()) 473 return false; 474 475 auto *CalleeFn = 476 dyn_cast_if_present<Function>(CallMI.getOperand(0).getGlobal()); 477 if (!CalleeFn) 478 return false; 479 480 auto *IntTy = dyn_cast<IntegerType>(CalleeFn->getReturnType()); 481 if (!IntTy) 482 return false; 483 484 const AttributeSet &Attrs = CalleeFn->getAttributes().getRetAttrs(); 485 unsigned BitWidth = IntTy->getBitWidth(); 486 if ((BitWidth <= 32 && Attrs.hasAttribute(Attribute::SExt)) || 487 (BitWidth < 32 && Attrs.hasAttribute(Attribute::ZExt))) 488 continue; 489 } 490 491 if (!AddRegToWorkList(CopySrcReg)) 492 return false; 493 494 break; 495 } 496 497 // For these, we just need to check if the 1st operand is sign extended. 498 case RISCV::BCLRI: 499 case RISCV::BINVI: 500 case RISCV::BSETI: 501 if (MI->getOperand(2).getImm() >= 31) 502 return false; 503 [[fallthrough]]; 504 case RISCV::REM: 505 case RISCV::ANDI: 506 case RISCV::ORI: 507 case RISCV::XORI: 508 // |Remainder| is always <= |Dividend|. If D is 32-bit, then so is R. 509 // DIV doesn't work because of the edge case 0xf..f 8000 0000 / (long)-1 510 // Logical operations use a sign extended 12-bit immediate. 
511 if (!AddRegToWorkList(MI->getOperand(1).getReg())) 512 return false; 513 514 break; 515 case RISCV::PseudoCCADDW: 516 case RISCV::PseudoCCADDIW: 517 case RISCV::PseudoCCSUBW: 518 case RISCV::PseudoCCSLLW: 519 case RISCV::PseudoCCSRLW: 520 case RISCV::PseudoCCSRAW: 521 case RISCV::PseudoCCSLLIW: 522 case RISCV::PseudoCCSRLIW: 523 case RISCV::PseudoCCSRAIW: 524 // Returns operand 4 or an ADDW/SUBW/etc. of operands 5 and 6. We only 525 // need to check if operand 4 is sign extended. 526 if (!AddRegToWorkList(MI->getOperand(4).getReg())) 527 return false; 528 break; 529 case RISCV::REMU: 530 case RISCV::AND: 531 case RISCV::OR: 532 case RISCV::XOR: 533 case RISCV::ANDN: 534 case RISCV::ORN: 535 case RISCV::XNOR: 536 case RISCV::MAX: 537 case RISCV::MAXU: 538 case RISCV::MIN: 539 case RISCV::MINU: 540 case RISCV::PseudoCCMOVGPR: 541 case RISCV::PseudoCCAND: 542 case RISCV::PseudoCCOR: 543 case RISCV::PseudoCCXOR: 544 case RISCV::PHI: { 545 // If all incoming values are sign-extended, the output of AND, OR, XOR, 546 // MIN, MAX, or PHI is also sign-extended. 547 548 // The input registers for PHI are operand 1, 3, ... 549 // The input registers for PseudoCCMOVGPR are 4 and 5. 550 // The input registers for PseudoCCAND/OR/XOR are 4, 5, and 6. 551 // The input registers for others are operand 1 and 2. 
552 unsigned B = 1, E = 3, D = 1; 553 switch (MI->getOpcode()) { 554 case RISCV::PHI: 555 E = MI->getNumOperands(); 556 D = 2; 557 break; 558 case RISCV::PseudoCCMOVGPR: 559 B = 4; 560 E = 6; 561 break; 562 case RISCV::PseudoCCAND: 563 case RISCV::PseudoCCOR: 564 case RISCV::PseudoCCXOR: 565 B = 4; 566 E = 7; 567 break; 568 } 569 570 for (unsigned I = B; I != E; I += D) { 571 if (!MI->getOperand(I).isReg()) 572 return false; 573 574 if (!AddRegToWorkList(MI->getOperand(I).getReg())) 575 return false; 576 } 577 578 break; 579 } 580 581 case RISCV::CZERO_EQZ: 582 case RISCV::CZERO_NEZ: 583 case RISCV::VT_MASKC: 584 case RISCV::VT_MASKCN: 585 // Instructions return zero or operand 1. Result is sign extended if 586 // operand 1 is sign extended. 587 if (!AddRegToWorkList(MI->getOperand(1).getReg())) 588 return false; 589 break; 590 591 // With these opcode, we can "fix" them with the W-version 592 // if we know all users of the result only rely on bits 31:0 593 case RISCV::SLLI: 594 // SLLIW reads the lowest 5 bits, while SLLI reads lowest 6 bits 595 if (MI->getOperand(2).getImm() >= 32) 596 return false; 597 [[fallthrough]]; 598 case RISCV::ADDI: 599 case RISCV::ADD: 600 case RISCV::LD: 601 case RISCV::LWU: 602 case RISCV::MUL: 603 case RISCV::SUB: 604 if (hasAllWUsers(*MI, ST, MRI)) { 605 FixableDef.insert(MI); 606 break; 607 } 608 return false; 609 } 610 } 611 612 // If we get here, then every node we visited produces a sign extended value 613 // or propagated sign extended values. So the result must be sign extended. 
614 return true; 615 } 616 617 static unsigned getWOp(unsigned Opcode) { 618 switch (Opcode) { 619 case RISCV::ADDI: 620 return RISCV::ADDIW; 621 case RISCV::ADD: 622 return RISCV::ADDW; 623 case RISCV::LD: 624 case RISCV::LWU: 625 return RISCV::LW; 626 case RISCV::MUL: 627 return RISCV::MULW; 628 case RISCV::SLLI: 629 return RISCV::SLLIW; 630 case RISCV::SUB: 631 return RISCV::SUBW; 632 default: 633 llvm_unreachable("Unexpected opcode for replacement with W variant"); 634 } 635 } 636 637 bool RISCVOptWInstrs::removeSExtWInstrs(MachineFunction &MF, 638 const RISCVInstrInfo &TII, 639 const RISCVSubtarget &ST, 640 MachineRegisterInfo &MRI) { 641 if (DisableSExtWRemoval) 642 return false; 643 644 bool MadeChange = false; 645 for (MachineBasicBlock &MBB : MF) { 646 for (MachineInstr &MI : llvm::make_early_inc_range(MBB)) { 647 // We're looking for the sext.w pattern ADDIW rd, rs1, 0. 648 if (!RISCV::isSEXT_W(MI)) 649 continue; 650 651 Register SrcReg = MI.getOperand(1).getReg(); 652 653 SmallPtrSet<MachineInstr *, 4> FixableDefs; 654 655 // If all users only use the lower bits, this sext.w is redundant. 656 // Or if all definitions reaching MI sign-extend their output, 657 // then sext.w is redundant. 658 if (!hasAllWUsers(MI, ST, MRI) && 659 !isSignExtendedW(SrcReg, ST, MRI, FixableDefs)) 660 continue; 661 662 Register DstReg = MI.getOperand(0).getReg(); 663 if (!MRI.constrainRegClass(SrcReg, MRI.getRegClass(DstReg))) 664 continue; 665 666 // Convert Fixable instructions to their W versions. 
667 for (MachineInstr *Fixable : FixableDefs) { 668 LLVM_DEBUG(dbgs() << "Replacing " << *Fixable); 669 Fixable->setDesc(TII.get(getWOp(Fixable->getOpcode()))); 670 Fixable->clearFlag(MachineInstr::MIFlag::NoSWrap); 671 Fixable->clearFlag(MachineInstr::MIFlag::NoUWrap); 672 Fixable->clearFlag(MachineInstr::MIFlag::IsExact); 673 LLVM_DEBUG(dbgs() << " with " << *Fixable); 674 ++NumTransformedToWInstrs; 675 } 676 677 LLVM_DEBUG(dbgs() << "Removing redundant sign-extension\n"); 678 MRI.replaceRegWith(DstReg, SrcReg); 679 MRI.clearKillFlags(SrcReg); 680 MI.eraseFromParent(); 681 ++NumRemovedSExtW; 682 MadeChange = true; 683 } 684 } 685 686 return MadeChange; 687 } 688 689 bool RISCVOptWInstrs::stripWSuffixes(MachineFunction &MF, 690 const RISCVInstrInfo &TII, 691 const RISCVSubtarget &ST, 692 MachineRegisterInfo &MRI) { 693 bool MadeChange = false; 694 for (MachineBasicBlock &MBB : MF) { 695 for (MachineInstr &MI : MBB) { 696 unsigned Opc; 697 switch (MI.getOpcode()) { 698 default: 699 continue; 700 case RISCV::ADDW: Opc = RISCV::ADD; break; 701 case RISCV::ADDIW: Opc = RISCV::ADDI; break; 702 case RISCV::MULW: Opc = RISCV::MUL; break; 703 case RISCV::SLLIW: Opc = RISCV::SLLI; break; 704 } 705 706 if (hasAllWUsers(MI, ST, MRI)) { 707 MI.setDesc(TII.get(Opc)); 708 MadeChange = true; 709 } 710 } 711 } 712 713 return MadeChange; 714 } 715 716 bool RISCVOptWInstrs::appendWSuffixes(MachineFunction &MF, 717 const RISCVInstrInfo &TII, 718 const RISCVSubtarget &ST, 719 MachineRegisterInfo &MRI) { 720 bool MadeChange = false; 721 for (MachineBasicBlock &MBB : MF) { 722 for (MachineInstr &MI : MBB) { 723 unsigned WOpc; 724 // TODO: Add more? 
725 switch (MI.getOpcode()) { 726 default: 727 continue; 728 case RISCV::ADD: 729 WOpc = RISCV::ADDW; 730 break; 731 case RISCV::ADDI: 732 WOpc = RISCV::ADDIW; 733 break; 734 case RISCV::SUB: 735 WOpc = RISCV::SUBW; 736 break; 737 case RISCV::MUL: 738 WOpc = RISCV::MULW; 739 break; 740 case RISCV::SLLI: 741 // SLLIW reads the lowest 5 bits, while SLLI reads lowest 6 bits 742 if (MI.getOperand(2).getImm() >= 32) 743 continue; 744 WOpc = RISCV::SLLIW; 745 break; 746 case RISCV::LD: 747 case RISCV::LWU: 748 WOpc = RISCV::LW; 749 break; 750 } 751 752 if (hasAllWUsers(MI, ST, MRI)) { 753 LLVM_DEBUG(dbgs() << "Replacing " << MI); 754 MI.setDesc(TII.get(WOpc)); 755 MI.clearFlag(MachineInstr::MIFlag::NoSWrap); 756 MI.clearFlag(MachineInstr::MIFlag::NoUWrap); 757 MI.clearFlag(MachineInstr::MIFlag::IsExact); 758 LLVM_DEBUG(dbgs() << " with " << MI); 759 ++NumTransformedToWInstrs; 760 MadeChange = true; 761 } 762 } 763 } 764 765 return MadeChange; 766 } 767 768 bool RISCVOptWInstrs::runOnMachineFunction(MachineFunction &MF) { 769 if (skipFunction(MF.getFunction())) 770 return false; 771 772 MachineRegisterInfo &MRI = MF.getRegInfo(); 773 const RISCVSubtarget &ST = MF.getSubtarget<RISCVSubtarget>(); 774 const RISCVInstrInfo &TII = *ST.getInstrInfo(); 775 776 if (!ST.is64Bit()) 777 return false; 778 779 bool MadeChange = false; 780 MadeChange |= removeSExtWInstrs(MF, TII, ST, MRI); 781 782 if (!(DisableStripWSuffix || ST.preferWInst())) 783 MadeChange |= stripWSuffixes(MF, TII, ST, MRI); 784 785 if (ST.preferWInst()) 786 MadeChange |= appendWSuffixes(MF, TII, ST, MRI); 787 788 return MadeChange; 789 } 790