1 //===-------------- RISCVVLOptimizer.cpp - VL Optimizer -------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===---------------------------------------------------------------------===// 8 // 9 // This pass reduces the VL where possible at the MI level, before VSETVLI 10 // instructions are inserted. 11 // 12 // The purpose of this optimization is to make the VL argument, for instructions 13 // that have a VL argument, as small as possible. This is implemented by 14 // visiting each instruction in reverse order and checking that if it has a VL 15 // argument, whether the VL can be reduced. 16 // 17 //===---------------------------------------------------------------------===// 18 19 #include "RISCV.h" 20 #include "RISCVSubtarget.h" 21 #include "llvm/ADT/PostOrderIterator.h" 22 #include "llvm/CodeGen/MachineDominators.h" 23 #include "llvm/CodeGen/MachineFunctionPass.h" 24 #include "llvm/InitializePasses.h" 25 26 using namespace llvm; 27 28 #define DEBUG_TYPE "riscv-vl-optimizer" 29 #define PASS_NAME "RISC-V VL Optimizer" 30 31 namespace { 32 33 class RISCVVLOptimizer : public MachineFunctionPass { 34 const MachineRegisterInfo *MRI; 35 const MachineDominatorTree *MDT; 36 37 public: 38 static char ID; 39 40 RISCVVLOptimizer() : MachineFunctionPass(ID) {} 41 42 bool runOnMachineFunction(MachineFunction &MF) override; 43 44 void getAnalysisUsage(AnalysisUsage &AU) const override { 45 AU.setPreservesCFG(); 46 AU.addRequired<MachineDominatorTreeWrapperPass>(); 47 MachineFunctionPass::getAnalysisUsage(AU); 48 } 49 50 StringRef getPassName() const override { return PASS_NAME; } 51 52 private: 53 std::optional<MachineOperand> getMinimumVLForUser(MachineOperand &UserOp); 54 /// Returns the largest common VL MachineOperand that may be used to optimize 55 /// MI. 
Returns std::nullopt if it failed to find a suitable VL. 56 std::optional<MachineOperand> checkUsers(MachineInstr &MI); 57 bool tryReduceVL(MachineInstr &MI); 58 bool isCandidate(const MachineInstr &MI) const; 59 60 /// For a given instruction, records what elements of it are demanded by 61 /// downstream users. 62 DenseMap<const MachineInstr *, std::optional<MachineOperand>> DemandedVLs; 63 }; 64 65 } // end anonymous namespace 66 67 char RISCVVLOptimizer::ID = 0; 68 INITIALIZE_PASS_BEGIN(RISCVVLOptimizer, DEBUG_TYPE, PASS_NAME, false, false) 69 INITIALIZE_PASS_DEPENDENCY(MachineDominatorTreeWrapperPass) 70 INITIALIZE_PASS_END(RISCVVLOptimizer, DEBUG_TYPE, PASS_NAME, false, false) 71 72 FunctionPass *llvm::createRISCVVLOptimizerPass() { 73 return new RISCVVLOptimizer(); 74 } 75 76 /// Return true if R is a physical or virtual vector register, false otherwise. 77 static bool isVectorRegClass(Register R, const MachineRegisterInfo *MRI) { 78 if (R.isPhysical()) 79 return RISCV::VRRegClass.contains(R); 80 const TargetRegisterClass *RC = MRI->getRegClass(R); 81 return RISCVRI::isVRegClass(RC->TSFlags); 82 } 83 84 /// Represents the EMUL and EEW of a MachineOperand. 85 struct OperandInfo { 86 // Represent as 1,2,4,8, ... and fractional indicator. This is because 87 // EMUL can take on values that don't map to RISCVII::VLMUL values exactly. 88 // For example, a mask operand can have an EMUL less than MF8. 
89 std::optional<std::pair<unsigned, bool>> EMUL; 90 91 unsigned Log2EEW; 92 93 OperandInfo(RISCVII::VLMUL EMUL, unsigned Log2EEW) 94 : EMUL(RISCVVType::decodeVLMUL(EMUL)), Log2EEW(Log2EEW) {} 95 96 OperandInfo(std::pair<unsigned, bool> EMUL, unsigned Log2EEW) 97 : EMUL(EMUL), Log2EEW(Log2EEW) {} 98 99 OperandInfo(unsigned Log2EEW) : Log2EEW(Log2EEW) {} 100 101 OperandInfo() = delete; 102 103 static bool EMULAndEEWAreEqual(const OperandInfo &A, const OperandInfo &B) { 104 return A.Log2EEW == B.Log2EEW && A.EMUL->first == B.EMUL->first && 105 A.EMUL->second == B.EMUL->second; 106 } 107 108 static bool EEWAreEqual(const OperandInfo &A, const OperandInfo &B) { 109 return A.Log2EEW == B.Log2EEW; 110 } 111 112 void print(raw_ostream &OS) const { 113 if (EMUL) { 114 OS << "EMUL: m"; 115 if (EMUL->second) 116 OS << "f"; 117 OS << EMUL->first; 118 } else 119 OS << "EMUL: unknown\n"; 120 OS << ", EEW: " << (1 << Log2EEW); 121 } 122 }; 123 124 LLVM_ATTRIBUTE_UNUSED 125 static raw_ostream &operator<<(raw_ostream &OS, const OperandInfo &OI) { 126 OI.print(OS); 127 return OS; 128 } 129 130 LLVM_ATTRIBUTE_UNUSED 131 static raw_ostream &operator<<(raw_ostream &OS, 132 const std::optional<OperandInfo> &OI) { 133 if (OI) 134 OI->print(OS); 135 else 136 OS << "nullopt"; 137 return OS; 138 } 139 140 namespace llvm { 141 namespace RISCVVType { 142 /// Return EMUL = (EEW / SEW) * LMUL where EEW comes from Log2EEW and LMUL and 143 /// SEW are from the TSFlags of MI. 144 static std::pair<unsigned, bool> 145 getEMULEqualsEEWDivSEWTimesLMUL(unsigned Log2EEW, const MachineInstr &MI) { 146 RISCVII::VLMUL MIVLMUL = RISCVII::getLMul(MI.getDesc().TSFlags); 147 auto [MILMUL, MILMULIsFractional] = RISCVVType::decodeVLMUL(MIVLMUL); 148 unsigned MILog2SEW = 149 MI.getOperand(RISCVII::getSEWOpNum(MI.getDesc())).getImm(); 150 151 // Mask instructions will have 0 as the SEW operand. But the LMUL of these 152 // instructions is calculated is as if the SEW operand was 3 (e8). 
153 if (MILog2SEW == 0) 154 MILog2SEW = 3; 155 156 unsigned MISEW = 1 << MILog2SEW; 157 158 unsigned EEW = 1 << Log2EEW; 159 // Calculate (EEW/SEW)*LMUL preserving fractions less than 1. Use GCD 160 // to put fraction in simplest form. 161 unsigned Num = EEW, Denom = MISEW; 162 int GCD = MILMULIsFractional ? std::gcd(Num, Denom * MILMUL) 163 : std::gcd(Num * MILMUL, Denom); 164 Num = MILMULIsFractional ? Num / GCD : Num * MILMUL / GCD; 165 Denom = MILMULIsFractional ? Denom * MILMUL / GCD : Denom / GCD; 166 return std::make_pair(Num > Denom ? Num : Denom, Denom > Num); 167 } 168 } // end namespace RISCVVType 169 } // end namespace llvm 170 171 /// Dest has EEW=SEW. Source EEW=SEW/Factor (i.e. F2 => EEW/2). 172 /// SEW comes from TSFlags of MI. 173 static unsigned getIntegerExtensionOperandEEW(unsigned Factor, 174 const MachineInstr &MI, 175 const MachineOperand &MO) { 176 unsigned MILog2SEW = 177 MI.getOperand(RISCVII::getSEWOpNum(MI.getDesc())).getImm(); 178 179 if (MO.getOperandNo() == 0) 180 return MILog2SEW; 181 182 unsigned MISEW = 1 << MILog2SEW; 183 unsigned EEW = MISEW / Factor; 184 unsigned Log2EEW = Log2_32(EEW); 185 186 return Log2EEW; 187 } 188 189 /// Check whether MO is a mask operand of MI. 190 static bool isMaskOperand(const MachineInstr &MI, const MachineOperand &MO, 191 const MachineRegisterInfo *MRI) { 192 193 if (!MO.isReg() || !isVectorRegClass(MO.getReg(), MRI)) 194 return false; 195 196 const MCInstrDesc &Desc = MI.getDesc(); 197 return Desc.operands()[MO.getOperandNo()].RegClass == RISCV::VMV0RegClassID; 198 } 199 200 static std::optional<unsigned> 201 getOperandLog2EEW(const MachineOperand &MO, const MachineRegisterInfo *MRI) { 202 const MachineInstr &MI = *MO.getParent(); 203 const RISCVVPseudosTable::PseudoInfo *RVV = 204 RISCVVPseudosTable::getPseudoInfo(MI.getOpcode()); 205 assert(RVV && "Could not find MI in PseudoTable"); 206 207 // MI has a SEW associated with it. 
The RVV specification defines 208 // the EEW of each operand and definition in relation to MI.SEW. 209 unsigned MILog2SEW = 210 MI.getOperand(RISCVII::getSEWOpNum(MI.getDesc())).getImm(); 211 212 const bool HasPassthru = RISCVII::isFirstDefTiedToFirstUse(MI.getDesc()); 213 const bool IsTied = RISCVII::isTiedPseudo(MI.getDesc().TSFlags); 214 215 bool IsMODef = MO.getOperandNo() == 0; 216 217 // All mask operands have EEW=1 218 if (isMaskOperand(MI, MO, MRI)) 219 return 0; 220 221 // switch against BaseInstr to reduce number of cases that need to be 222 // considered. 223 switch (RVV->BaseInstr) { 224 225 // 6. Configuration-Setting Instructions 226 // Configuration setting instructions do not read or write vector registers 227 case RISCV::VSETIVLI: 228 case RISCV::VSETVL: 229 case RISCV::VSETVLI: 230 llvm_unreachable("Configuration setting instructions do not read or write " 231 "vector registers"); 232 233 // Vector Loads and Stores 234 // Vector Unit-Stride Instructions 235 // Vector Strided Instructions 236 /// Dest EEW encoded in the instruction 237 case RISCV::VLM_V: 238 case RISCV::VSM_V: 239 return 0; 240 case RISCV::VLE8_V: 241 case RISCV::VSE8_V: 242 case RISCV::VLSE8_V: 243 case RISCV::VSSE8_V: 244 return 3; 245 case RISCV::VLE16_V: 246 case RISCV::VSE16_V: 247 case RISCV::VLSE16_V: 248 case RISCV::VSSE16_V: 249 return 4; 250 case RISCV::VLE32_V: 251 case RISCV::VSE32_V: 252 case RISCV::VLSE32_V: 253 case RISCV::VSSE32_V: 254 return 5; 255 case RISCV::VLE64_V: 256 case RISCV::VSE64_V: 257 case RISCV::VLSE64_V: 258 case RISCV::VSSE64_V: 259 return 6; 260 261 // Vector Indexed Instructions 262 // vs(o|u)xei<eew>.v 263 // Dest/Data (operand 0) EEW=SEW. Source EEW=<eew>. 
264 case RISCV::VLUXEI8_V: 265 case RISCV::VLOXEI8_V: 266 case RISCV::VSUXEI8_V: 267 case RISCV::VSOXEI8_V: { 268 if (MO.getOperandNo() == 0) 269 return MILog2SEW; 270 return 3; 271 } 272 case RISCV::VLUXEI16_V: 273 case RISCV::VLOXEI16_V: 274 case RISCV::VSUXEI16_V: 275 case RISCV::VSOXEI16_V: { 276 if (MO.getOperandNo() == 0) 277 return MILog2SEW; 278 return 4; 279 } 280 case RISCV::VLUXEI32_V: 281 case RISCV::VLOXEI32_V: 282 case RISCV::VSUXEI32_V: 283 case RISCV::VSOXEI32_V: { 284 if (MO.getOperandNo() == 0) 285 return MILog2SEW; 286 return 5; 287 } 288 case RISCV::VLUXEI64_V: 289 case RISCV::VLOXEI64_V: 290 case RISCV::VSUXEI64_V: 291 case RISCV::VSOXEI64_V: { 292 if (MO.getOperandNo() == 0) 293 return MILog2SEW; 294 return 6; 295 } 296 297 // Vector Integer Arithmetic Instructions 298 // Vector Single-Width Integer Add and Subtract 299 case RISCV::VADD_VI: 300 case RISCV::VADD_VV: 301 case RISCV::VADD_VX: 302 case RISCV::VSUB_VV: 303 case RISCV::VSUB_VX: 304 case RISCV::VRSUB_VI: 305 case RISCV::VRSUB_VX: 306 // Vector Bitwise Logical Instructions 307 // Vector Single-Width Shift Instructions 308 // EEW=SEW. 309 case RISCV::VAND_VI: 310 case RISCV::VAND_VV: 311 case RISCV::VAND_VX: 312 case RISCV::VOR_VI: 313 case RISCV::VOR_VV: 314 case RISCV::VOR_VX: 315 case RISCV::VXOR_VI: 316 case RISCV::VXOR_VV: 317 case RISCV::VXOR_VX: 318 case RISCV::VSLL_VI: 319 case RISCV::VSLL_VV: 320 case RISCV::VSLL_VX: 321 case RISCV::VSRL_VI: 322 case RISCV::VSRL_VV: 323 case RISCV::VSRL_VX: 324 case RISCV::VSRA_VI: 325 case RISCV::VSRA_VV: 326 case RISCV::VSRA_VX: 327 // Vector Integer Min/Max Instructions 328 // EEW=SEW. 329 case RISCV::VMINU_VV: 330 case RISCV::VMINU_VX: 331 case RISCV::VMIN_VV: 332 case RISCV::VMIN_VX: 333 case RISCV::VMAXU_VV: 334 case RISCV::VMAXU_VX: 335 case RISCV::VMAX_VV: 336 case RISCV::VMAX_VX: 337 // Vector Single-Width Integer Multiply Instructions 338 // Source and Dest EEW=SEW. 
339 case RISCV::VMUL_VV: 340 case RISCV::VMUL_VX: 341 case RISCV::VMULH_VV: 342 case RISCV::VMULH_VX: 343 case RISCV::VMULHU_VV: 344 case RISCV::VMULHU_VX: 345 case RISCV::VMULHSU_VV: 346 case RISCV::VMULHSU_VX: 347 // Vector Integer Divide Instructions 348 // EEW=SEW. 349 case RISCV::VDIVU_VV: 350 case RISCV::VDIVU_VX: 351 case RISCV::VDIV_VV: 352 case RISCV::VDIV_VX: 353 case RISCV::VREMU_VV: 354 case RISCV::VREMU_VX: 355 case RISCV::VREM_VV: 356 case RISCV::VREM_VX: 357 // Vector Single-Width Integer Multiply-Add Instructions 358 // EEW=SEW. 359 case RISCV::VMACC_VV: 360 case RISCV::VMACC_VX: 361 case RISCV::VNMSAC_VV: 362 case RISCV::VNMSAC_VX: 363 case RISCV::VMADD_VV: 364 case RISCV::VMADD_VX: 365 case RISCV::VNMSUB_VV: 366 case RISCV::VNMSUB_VX: 367 // Vector Integer Merge Instructions 368 // Vector Integer Add-with-Carry / Subtract-with-Borrow Instructions 369 // EEW=SEW, except the mask operand has EEW=1. Mask operand is handled 370 // before this switch. 371 case RISCV::VMERGE_VIM: 372 case RISCV::VMERGE_VVM: 373 case RISCV::VMERGE_VXM: 374 case RISCV::VADC_VIM: 375 case RISCV::VADC_VVM: 376 case RISCV::VADC_VXM: 377 case RISCV::VSBC_VVM: 378 case RISCV::VSBC_VXM: 379 // Vector Integer Move Instructions 380 // Vector Fixed-Point Arithmetic Instructions 381 // Vector Single-Width Saturating Add and Subtract 382 // Vector Single-Width Averaging Add and Subtract 383 // EEW=SEW. 
384 case RISCV::VMV_V_I: 385 case RISCV::VMV_V_V: 386 case RISCV::VMV_V_X: 387 case RISCV::VSADDU_VI: 388 case RISCV::VSADDU_VV: 389 case RISCV::VSADDU_VX: 390 case RISCV::VSADD_VI: 391 case RISCV::VSADD_VV: 392 case RISCV::VSADD_VX: 393 case RISCV::VSSUBU_VV: 394 case RISCV::VSSUBU_VX: 395 case RISCV::VSSUB_VV: 396 case RISCV::VSSUB_VX: 397 case RISCV::VAADDU_VV: 398 case RISCV::VAADDU_VX: 399 case RISCV::VAADD_VV: 400 case RISCV::VAADD_VX: 401 case RISCV::VASUBU_VV: 402 case RISCV::VASUBU_VX: 403 case RISCV::VASUB_VV: 404 case RISCV::VASUB_VX: 405 // Vector Single-Width Fractional Multiply with Rounding and Saturation 406 // EEW=SEW. The instruction produces 2*SEW product internally but 407 // saturates to fit into SEW bits. 408 case RISCV::VSMUL_VV: 409 case RISCV::VSMUL_VX: 410 // Vector Single-Width Scaling Shift Instructions 411 // EEW=SEW. 412 case RISCV::VSSRL_VI: 413 case RISCV::VSSRL_VV: 414 case RISCV::VSSRL_VX: 415 case RISCV::VSSRA_VI: 416 case RISCV::VSSRA_VV: 417 case RISCV::VSSRA_VX: 418 // Vector Permutation Instructions 419 // Integer Scalar Move Instructions 420 // Floating-Point Scalar Move Instructions 421 // EEW=SEW. 422 case RISCV::VMV_X_S: 423 case RISCV::VMV_S_X: 424 case RISCV::VFMV_F_S: 425 case RISCV::VFMV_S_F: 426 // Vector Slide Instructions 427 // EEW=SEW. 428 case RISCV::VSLIDEUP_VI: 429 case RISCV::VSLIDEUP_VX: 430 case RISCV::VSLIDEDOWN_VI: 431 case RISCV::VSLIDEDOWN_VX: 432 case RISCV::VSLIDE1UP_VX: 433 case RISCV::VFSLIDE1UP_VF: 434 case RISCV::VSLIDE1DOWN_VX: 435 case RISCV::VFSLIDE1DOWN_VF: 436 // Vector Register Gather Instructions 437 // EEW=SEW. For mask operand, EEW=1. 438 case RISCV::VRGATHER_VI: 439 case RISCV::VRGATHER_VV: 440 case RISCV::VRGATHER_VX: 441 // Vector Compress Instruction 442 // EEW=SEW. 
443 case RISCV::VCOMPRESS_VM: 444 // Vector Element Index Instruction 445 case RISCV::VID_V: 446 // Vector Single-Width Floating-Point Add/Subtract Instructions 447 case RISCV::VFADD_VF: 448 case RISCV::VFADD_VV: 449 case RISCV::VFSUB_VF: 450 case RISCV::VFSUB_VV: 451 case RISCV::VFRSUB_VF: 452 // Vector Single-Width Floating-Point Multiply/Divide Instructions 453 case RISCV::VFMUL_VF: 454 case RISCV::VFMUL_VV: 455 case RISCV::VFDIV_VF: 456 case RISCV::VFDIV_VV: 457 case RISCV::VFRDIV_VF: 458 // Vector Floating-Point Square-Root Instruction 459 case RISCV::VFSQRT_V: 460 // Vector Floating-Point Reciprocal Square-Root Estimate Instruction 461 case RISCV::VFRSQRT7_V: 462 // Vector Floating-Point Reciprocal Estimate Instruction 463 case RISCV::VFREC7_V: 464 // Vector Floating-Point MIN/MAX Instructions 465 case RISCV::VFMIN_VF: 466 case RISCV::VFMIN_VV: 467 case RISCV::VFMAX_VF: 468 case RISCV::VFMAX_VV: 469 // Vector Floating-Point Sign-Injection Instructions 470 case RISCV::VFSGNJ_VF: 471 case RISCV::VFSGNJ_VV: 472 case RISCV::VFSGNJN_VV: 473 case RISCV::VFSGNJN_VF: 474 case RISCV::VFSGNJX_VF: 475 case RISCV::VFSGNJX_VV: 476 // Vector Floating-Point Classify Instruction 477 case RISCV::VFCLASS_V: 478 // Vector Floating-Point Move Instruction 479 case RISCV::VFMV_V_F: 480 // Single-Width Floating-Point/Integer Type-Convert Instructions 481 case RISCV::VFCVT_XU_F_V: 482 case RISCV::VFCVT_X_F_V: 483 case RISCV::VFCVT_RTZ_XU_F_V: 484 case RISCV::VFCVT_RTZ_X_F_V: 485 case RISCV::VFCVT_F_XU_V: 486 case RISCV::VFCVT_F_X_V: 487 // Vector Floating-Point Merge Instruction 488 case RISCV::VFMERGE_VFM: 489 // Vector count population in mask vcpop.m 490 // vfirst find-first-set mask bit 491 case RISCV::VCPOP_M: 492 case RISCV::VFIRST_M: 493 return MILog2SEW; 494 495 // Vector Widening Integer Add/Subtract 496 // Def uses EEW=2*SEW . Operands use EEW=SEW. 
497 case RISCV::VWADDU_VV: 498 case RISCV::VWADDU_VX: 499 case RISCV::VWSUBU_VV: 500 case RISCV::VWSUBU_VX: 501 case RISCV::VWADD_VV: 502 case RISCV::VWADD_VX: 503 case RISCV::VWSUB_VV: 504 case RISCV::VWSUB_VX: 505 case RISCV::VWSLL_VI: 506 // Vector Widening Integer Multiply Instructions 507 // Destination EEW=2*SEW. Source EEW=SEW. 508 case RISCV::VWMUL_VV: 509 case RISCV::VWMUL_VX: 510 case RISCV::VWMULSU_VV: 511 case RISCV::VWMULSU_VX: 512 case RISCV::VWMULU_VV: 513 case RISCV::VWMULU_VX: 514 // Vector Widening Integer Multiply-Add Instructions 515 // Destination EEW=2*SEW. Source EEW=SEW. 516 // A SEW-bit*SEW-bit multiply of the sources forms a 2*SEW-bit value, which 517 // is then added to the 2*SEW-bit Dest. These instructions never have a 518 // passthru operand. 519 case RISCV::VWMACCU_VV: 520 case RISCV::VWMACCU_VX: 521 case RISCV::VWMACC_VV: 522 case RISCV::VWMACC_VX: 523 case RISCV::VWMACCSU_VV: 524 case RISCV::VWMACCSU_VX: 525 case RISCV::VWMACCUS_VX: 526 // Vector Widening Floating-Point Fused Multiply-Add Instructions 527 case RISCV::VFWMACC_VF: 528 case RISCV::VFWMACC_VV: 529 case RISCV::VFWNMACC_VF: 530 case RISCV::VFWNMACC_VV: 531 case RISCV::VFWMSAC_VF: 532 case RISCV::VFWMSAC_VV: 533 case RISCV::VFWNMSAC_VF: 534 case RISCV::VFWNMSAC_VV: 535 // Vector Widening Floating-Point Add/Subtract Instructions 536 // Dest EEW=2*SEW. Source EEW=SEW. 537 case RISCV::VFWADD_VV: 538 case RISCV::VFWADD_VF: 539 case RISCV::VFWSUB_VV: 540 case RISCV::VFWSUB_VF: 541 // Vector Widening Floating-Point Multiply 542 case RISCV::VFWMUL_VF: 543 case RISCV::VFWMUL_VV: 544 // Widening Floating-Point/Integer Type-Convert Instructions 545 case RISCV::VFWCVT_XU_F_V: 546 case RISCV::VFWCVT_X_F_V: 547 case RISCV::VFWCVT_RTZ_XU_F_V: 548 case RISCV::VFWCVT_RTZ_X_F_V: 549 case RISCV::VFWCVT_F_XU_V: 550 case RISCV::VFWCVT_F_X_V: 551 case RISCV::VFWCVT_F_F_V: 552 case RISCV::VFWCVTBF16_F_F_V: 553 return IsMODef ? MILog2SEW + 1 : MILog2SEW; 554 555 // Def and Op1 uses EEW=2*SEW. 
Op2 uses EEW=SEW. 556 case RISCV::VWADDU_WV: 557 case RISCV::VWADDU_WX: 558 case RISCV::VWSUBU_WV: 559 case RISCV::VWSUBU_WX: 560 case RISCV::VWADD_WV: 561 case RISCV::VWADD_WX: 562 case RISCV::VWSUB_WV: 563 case RISCV::VWSUB_WX: 564 // Vector Widening Floating-Point Add/Subtract Instructions 565 case RISCV::VFWADD_WF: 566 case RISCV::VFWADD_WV: 567 case RISCV::VFWSUB_WF: 568 case RISCV::VFWSUB_WV: { 569 bool IsOp1 = (HasPassthru && !IsTied) ? MO.getOperandNo() == 2 570 : MO.getOperandNo() == 1; 571 bool TwoTimes = IsMODef || IsOp1; 572 return TwoTimes ? MILog2SEW + 1 : MILog2SEW; 573 } 574 575 // Vector Integer Extension 576 case RISCV::VZEXT_VF2: 577 case RISCV::VSEXT_VF2: 578 return getIntegerExtensionOperandEEW(2, MI, MO); 579 case RISCV::VZEXT_VF4: 580 case RISCV::VSEXT_VF4: 581 return getIntegerExtensionOperandEEW(4, MI, MO); 582 case RISCV::VZEXT_VF8: 583 case RISCV::VSEXT_VF8: 584 return getIntegerExtensionOperandEEW(8, MI, MO); 585 586 // Vector Narrowing Integer Right Shift Instructions 587 // Destination EEW=SEW, Op 1 has EEW=2*SEW. Op2 has EEW=SEW 588 case RISCV::VNSRL_WX: 589 case RISCV::VNSRL_WI: 590 case RISCV::VNSRL_WV: 591 case RISCV::VNSRA_WI: 592 case RISCV::VNSRA_WV: 593 case RISCV::VNSRA_WX: 594 // Vector Narrowing Fixed-Point Clip Instructions 595 // Destination and Op1 EEW=SEW. Op2 EEW=2*SEW. 596 case RISCV::VNCLIPU_WI: 597 case RISCV::VNCLIPU_WV: 598 case RISCV::VNCLIPU_WX: 599 case RISCV::VNCLIP_WI: 600 case RISCV::VNCLIP_WV: 601 case RISCV::VNCLIP_WX: 602 // Narrowing Floating-Point/Integer Type-Convert Instructions 603 case RISCV::VFNCVT_XU_F_W: 604 case RISCV::VFNCVT_X_F_W: 605 case RISCV::VFNCVT_RTZ_XU_F_W: 606 case RISCV::VFNCVT_RTZ_X_F_W: 607 case RISCV::VFNCVT_F_XU_W: 608 case RISCV::VFNCVT_F_X_W: 609 case RISCV::VFNCVT_F_F_W: 610 case RISCV::VFNCVT_ROD_F_F_W: 611 case RISCV::VFNCVTBF16_F_F_W: { 612 assert(!IsTied); 613 bool IsOp1 = HasPassthru ? 
MO.getOperandNo() == 2 : MO.getOperandNo() == 1; 614 bool TwoTimes = IsOp1; 615 return TwoTimes ? MILog2SEW + 1 : MILog2SEW; 616 } 617 618 // Vector Mask Instructions 619 // Vector Mask-Register Logical Instructions 620 // vmsbf.m set-before-first mask bit 621 // vmsif.m set-including-first mask bit 622 // vmsof.m set-only-first mask bit 623 // EEW=1 624 // We handle the cases when operand is a v0 mask operand above the switch, 625 // but these instructions may use non-v0 mask operands and need to be handled 626 // specifically. 627 case RISCV::VMAND_MM: 628 case RISCV::VMNAND_MM: 629 case RISCV::VMANDN_MM: 630 case RISCV::VMXOR_MM: 631 case RISCV::VMOR_MM: 632 case RISCV::VMNOR_MM: 633 case RISCV::VMORN_MM: 634 case RISCV::VMXNOR_MM: 635 case RISCV::VMSBF_M: 636 case RISCV::VMSIF_M: 637 case RISCV::VMSOF_M: { 638 return MILog2SEW; 639 } 640 641 // Vector Iota Instruction 642 // EEW=SEW, except the mask operand has EEW=1. Mask operand is not handled 643 // before this switch. 644 case RISCV::VIOTA_M: { 645 if (IsMODef || MO.getOperandNo() == 1) 646 return MILog2SEW; 647 return 0; 648 } 649 650 // Vector Integer Compare Instructions 651 // Dest EEW=1. Source EEW=SEW. 652 case RISCV::VMSEQ_VI: 653 case RISCV::VMSEQ_VV: 654 case RISCV::VMSEQ_VX: 655 case RISCV::VMSNE_VI: 656 case RISCV::VMSNE_VV: 657 case RISCV::VMSNE_VX: 658 case RISCV::VMSLTU_VV: 659 case RISCV::VMSLTU_VX: 660 case RISCV::VMSLT_VV: 661 case RISCV::VMSLT_VX: 662 case RISCV::VMSLEU_VV: 663 case RISCV::VMSLEU_VI: 664 case RISCV::VMSLEU_VX: 665 case RISCV::VMSLE_VV: 666 case RISCV::VMSLE_VI: 667 case RISCV::VMSLE_VX: 668 case RISCV::VMSGTU_VI: 669 case RISCV::VMSGTU_VX: 670 case RISCV::VMSGT_VI: 671 case RISCV::VMSGT_VX: 672 // Vector Integer Add-with-Carry / Subtract-with-Borrow Instructions 673 // Dest EEW=1. Source EEW=SEW. Mask source operand handled above this switch. 
674 case RISCV::VMADC_VIM: 675 case RISCV::VMADC_VVM: 676 case RISCV::VMADC_VXM: 677 case RISCV::VMSBC_VVM: 678 case RISCV::VMSBC_VXM: 679 // Dest EEW=1. Source EEW=SEW. 680 case RISCV::VMADC_VV: 681 case RISCV::VMADC_VI: 682 case RISCV::VMADC_VX: 683 case RISCV::VMSBC_VV: 684 case RISCV::VMSBC_VX: 685 // 13.13. Vector Floating-Point Compare Instructions 686 // Dest EEW=1. Source EEW=SEW 687 case RISCV::VMFEQ_VF: 688 case RISCV::VMFEQ_VV: 689 case RISCV::VMFNE_VF: 690 case RISCV::VMFNE_VV: 691 case RISCV::VMFLT_VF: 692 case RISCV::VMFLT_VV: 693 case RISCV::VMFLE_VF: 694 case RISCV::VMFLE_VV: 695 case RISCV::VMFGT_VF: 696 case RISCV::VMFGE_VF: { 697 if (IsMODef) 698 return 0; 699 return MILog2SEW; 700 } 701 702 // Vector Reduction Operations 703 // Vector Single-Width Integer Reduction Instructions 704 case RISCV::VREDAND_VS: 705 case RISCV::VREDMAX_VS: 706 case RISCV::VREDMAXU_VS: 707 case RISCV::VREDMIN_VS: 708 case RISCV::VREDMINU_VS: 709 case RISCV::VREDOR_VS: 710 case RISCV::VREDSUM_VS: 711 case RISCV::VREDXOR_VS: 712 // Vector Single-Width Floating-Point Reduction Instructions 713 case RISCV::VFREDMAX_VS: 714 case RISCV::VFREDMIN_VS: 715 case RISCV::VFREDOSUM_VS: 716 case RISCV::VFREDUSUM_VS: { 717 return MILog2SEW; 718 } 719 720 // Vector Widening Integer Reduction Instructions 721 // The Dest and VS1 read only element 0 for the vector register. Return 722 // 2*EEW for these. VS2 has EEW=SEW and EMUL=LMUL. 723 case RISCV::VWREDSUM_VS: 724 case RISCV::VWREDSUMU_VS: 725 // Vector Widening Floating-Point Reduction Instructions 726 case RISCV::VFWREDOSUM_VS: 727 case RISCV::VFWREDUSUM_VS: { 728 bool TwoTimes = IsMODef || MO.getOperandNo() == 3; 729 return TwoTimes ? 
MILog2SEW + 1 : MILog2SEW; 730 } 731 732 default: 733 return std::nullopt; 734 } 735 } 736 737 static std::optional<OperandInfo> 738 getOperandInfo(const MachineOperand &MO, const MachineRegisterInfo *MRI) { 739 const MachineInstr &MI = *MO.getParent(); 740 const RISCVVPseudosTable::PseudoInfo *RVV = 741 RISCVVPseudosTable::getPseudoInfo(MI.getOpcode()); 742 assert(RVV && "Could not find MI in PseudoTable"); 743 744 std::optional<unsigned> Log2EEW = getOperandLog2EEW(MO, MRI); 745 if (!Log2EEW) 746 return std::nullopt; 747 748 switch (RVV->BaseInstr) { 749 // Vector Reduction Operations 750 // Vector Single-Width Integer Reduction Instructions 751 // Vector Widening Integer Reduction Instructions 752 // Vector Widening Floating-Point Reduction Instructions 753 // The Dest and VS1 only read element 0 of the vector register. Return just 754 // the EEW for these. 755 case RISCV::VREDAND_VS: 756 case RISCV::VREDMAX_VS: 757 case RISCV::VREDMAXU_VS: 758 case RISCV::VREDMIN_VS: 759 case RISCV::VREDMINU_VS: 760 case RISCV::VREDOR_VS: 761 case RISCV::VREDSUM_VS: 762 case RISCV::VREDXOR_VS: 763 case RISCV::VWREDSUM_VS: 764 case RISCV::VWREDSUMU_VS: 765 case RISCV::VFWREDOSUM_VS: 766 case RISCV::VFWREDUSUM_VS: 767 if (MO.getOperandNo() != 2) 768 return OperandInfo(*Log2EEW); 769 break; 770 }; 771 772 // All others have EMUL=EEW/SEW*LMUL 773 return OperandInfo(RISCVVType::getEMULEqualsEEWDivSEWTimesLMUL(*Log2EEW, MI), 774 *Log2EEW); 775 } 776 777 /// Return true if this optimization should consider MI for VL reduction. This 778 /// white-list approach simplifies this optimization for instructions that may 779 /// have more complex semantics with relation to how it uses VL. 
780 static bool isSupportedInstr(const MachineInstr &MI) { 781 const RISCVVPseudosTable::PseudoInfo *RVV = 782 RISCVVPseudosTable::getPseudoInfo(MI.getOpcode()); 783 784 if (!RVV) 785 return false; 786 787 switch (RVV->BaseInstr) { 788 // Vector Unit-Stride Instructions 789 // Vector Strided Instructions 790 case RISCV::VLM_V: 791 case RISCV::VLE8_V: 792 case RISCV::VLSE8_V: 793 case RISCV::VLE16_V: 794 case RISCV::VLSE16_V: 795 case RISCV::VLE32_V: 796 case RISCV::VLSE32_V: 797 case RISCV::VLE64_V: 798 case RISCV::VLSE64_V: 799 // Vector Indexed Instructions 800 case RISCV::VLUXEI8_V: 801 case RISCV::VLOXEI8_V: 802 case RISCV::VLUXEI16_V: 803 case RISCV::VLOXEI16_V: 804 case RISCV::VLUXEI32_V: 805 case RISCV::VLOXEI32_V: 806 case RISCV::VLUXEI64_V: 807 case RISCV::VLOXEI64_V: { 808 for (const MachineMemOperand *MMO : MI.memoperands()) 809 if (MMO->isVolatile()) 810 return false; 811 return true; 812 } 813 814 // Vector Single-Width Integer Add and Subtract 815 case RISCV::VADD_VI: 816 case RISCV::VADD_VV: 817 case RISCV::VADD_VX: 818 case RISCV::VSUB_VV: 819 case RISCV::VSUB_VX: 820 case RISCV::VRSUB_VI: 821 case RISCV::VRSUB_VX: 822 // Vector Bitwise Logical Instructions 823 // Vector Single-Width Shift Instructions 824 case RISCV::VAND_VI: 825 case RISCV::VAND_VV: 826 case RISCV::VAND_VX: 827 case RISCV::VOR_VI: 828 case RISCV::VOR_VV: 829 case RISCV::VOR_VX: 830 case RISCV::VXOR_VI: 831 case RISCV::VXOR_VV: 832 case RISCV::VXOR_VX: 833 case RISCV::VSLL_VI: 834 case RISCV::VSLL_VV: 835 case RISCV::VSLL_VX: 836 case RISCV::VSRL_VI: 837 case RISCV::VSRL_VV: 838 case RISCV::VSRL_VX: 839 case RISCV::VSRA_VI: 840 case RISCV::VSRA_VV: 841 case RISCV::VSRA_VX: 842 // Vector Widening Integer Add/Subtract 843 case RISCV::VWADDU_VV: 844 case RISCV::VWADDU_VX: 845 case RISCV::VWSUBU_VV: 846 case RISCV::VWSUBU_VX: 847 case RISCV::VWADD_VV: 848 case RISCV::VWADD_VX: 849 case RISCV::VWSUB_VV: 850 case RISCV::VWSUB_VX: 851 case RISCV::VWADDU_WV: 852 case RISCV::VWADDU_WX: 853 
case RISCV::VWSUBU_WV: 854 case RISCV::VWSUBU_WX: 855 case RISCV::VWADD_WV: 856 case RISCV::VWADD_WX: 857 case RISCV::VWSUB_WV: 858 case RISCV::VWSUB_WX: 859 // Vector Integer Extension 860 case RISCV::VZEXT_VF2: 861 case RISCV::VSEXT_VF2: 862 case RISCV::VZEXT_VF4: 863 case RISCV::VSEXT_VF4: 864 case RISCV::VZEXT_VF8: 865 case RISCV::VSEXT_VF8: 866 // Vector Integer Add-with-Carry / Subtract-with-Borrow Instructions 867 // FIXME: Add support 868 case RISCV::VMADC_VV: 869 case RISCV::VMADC_VI: 870 case RISCV::VMADC_VX: 871 case RISCV::VMSBC_VV: 872 case RISCV::VMSBC_VX: 873 // Vector Narrowing Integer Right Shift Instructions 874 case RISCV::VNSRL_WX: 875 case RISCV::VNSRL_WI: 876 case RISCV::VNSRL_WV: 877 case RISCV::VNSRA_WI: 878 case RISCV::VNSRA_WV: 879 case RISCV::VNSRA_WX: 880 // Vector Integer Compare Instructions 881 case RISCV::VMSEQ_VI: 882 case RISCV::VMSEQ_VV: 883 case RISCV::VMSEQ_VX: 884 case RISCV::VMSNE_VI: 885 case RISCV::VMSNE_VV: 886 case RISCV::VMSNE_VX: 887 case RISCV::VMSLTU_VV: 888 case RISCV::VMSLTU_VX: 889 case RISCV::VMSLT_VV: 890 case RISCV::VMSLT_VX: 891 case RISCV::VMSLEU_VV: 892 case RISCV::VMSLEU_VI: 893 case RISCV::VMSLEU_VX: 894 case RISCV::VMSLE_VV: 895 case RISCV::VMSLE_VI: 896 case RISCV::VMSLE_VX: 897 case RISCV::VMSGTU_VI: 898 case RISCV::VMSGTU_VX: 899 case RISCV::VMSGT_VI: 900 case RISCV::VMSGT_VX: 901 // Vector Integer Min/Max Instructions 902 case RISCV::VMINU_VV: 903 case RISCV::VMINU_VX: 904 case RISCV::VMIN_VV: 905 case RISCV::VMIN_VX: 906 case RISCV::VMAXU_VV: 907 case RISCV::VMAXU_VX: 908 case RISCV::VMAX_VV: 909 case RISCV::VMAX_VX: 910 // Vector Single-Width Integer Multiply Instructions 911 case RISCV::VMUL_VV: 912 case RISCV::VMUL_VX: 913 case RISCV::VMULH_VV: 914 case RISCV::VMULH_VX: 915 case RISCV::VMULHU_VV: 916 case RISCV::VMULHU_VX: 917 case RISCV::VMULHSU_VV: 918 case RISCV::VMULHSU_VX: 919 // Vector Integer Divide Instructions 920 case RISCV::VDIVU_VV: 921 case RISCV::VDIVU_VX: 922 case RISCV::VDIV_VV: 923 
case RISCV::VDIV_VX: 924 case RISCV::VREMU_VV: 925 case RISCV::VREMU_VX: 926 case RISCV::VREM_VV: 927 case RISCV::VREM_VX: 928 // Vector Widening Integer Multiply Instructions 929 case RISCV::VWMUL_VV: 930 case RISCV::VWMUL_VX: 931 case RISCV::VWMULSU_VV: 932 case RISCV::VWMULSU_VX: 933 case RISCV::VWMULU_VV: 934 case RISCV::VWMULU_VX: 935 // Vector Single-Width Integer Multiply-Add Instructions 936 case RISCV::VMACC_VV: 937 case RISCV::VMACC_VX: 938 case RISCV::VNMSAC_VV: 939 case RISCV::VNMSAC_VX: 940 case RISCV::VMADD_VV: 941 case RISCV::VMADD_VX: 942 case RISCV::VNMSUB_VV: 943 case RISCV::VNMSUB_VX: 944 // Vector Integer Merge Instructions 945 case RISCV::VMERGE_VIM: 946 case RISCV::VMERGE_VVM: 947 case RISCV::VMERGE_VXM: 948 // Vector Integer Add-with-Carry / Subtract-with-Borrow Instructions 949 case RISCV::VADC_VIM: 950 case RISCV::VADC_VVM: 951 case RISCV::VADC_VXM: 952 // Vector Widening Integer Multiply-Add Instructions 953 case RISCV::VWMACCU_VV: 954 case RISCV::VWMACCU_VX: 955 case RISCV::VWMACC_VV: 956 case RISCV::VWMACC_VX: 957 case RISCV::VWMACCSU_VV: 958 case RISCV::VWMACCSU_VX: 959 case RISCV::VWMACCUS_VX: 960 // Vector Integer Merge Instructions 961 // FIXME: Add support 962 // Vector Integer Move Instructions 963 // FIXME: Add support 964 case RISCV::VMV_V_I: 965 case RISCV::VMV_V_X: 966 case RISCV::VMV_V_V: 967 // Vector Single-Width Averaging Add and Subtract 968 case RISCV::VAADDU_VV: 969 case RISCV::VAADDU_VX: 970 case RISCV::VAADD_VV: 971 case RISCV::VAADD_VX: 972 case RISCV::VASUBU_VV: 973 case RISCV::VASUBU_VX: 974 case RISCV::VASUB_VV: 975 case RISCV::VASUB_VX: 976 977 // Vector Crypto 978 case RISCV::VWSLL_VI: 979 980 // Vector Mask Instructions 981 // Vector Mask-Register Logical Instructions 982 // vmsbf.m set-before-first mask bit 983 // vmsif.m set-including-first mask bit 984 // vmsof.m set-only-first mask bit 985 // Vector Iota Instruction 986 // Vector Element Index Instruction 987 case RISCV::VMAND_MM: 988 case RISCV::VMNAND_MM: 
989 case RISCV::VMANDN_MM: 990 case RISCV::VMXOR_MM: 991 case RISCV::VMOR_MM: 992 case RISCV::VMNOR_MM: 993 case RISCV::VMORN_MM: 994 case RISCV::VMXNOR_MM: 995 case RISCV::VMSBF_M: 996 case RISCV::VMSIF_M: 997 case RISCV::VMSOF_M: 998 case RISCV::VIOTA_M: 999 case RISCV::VID_V: 1000 // Vector Single-Width Floating-Point Add/Subtract Instructions 1001 case RISCV::VFADD_VF: 1002 case RISCV::VFADD_VV: 1003 case RISCV::VFSUB_VF: 1004 case RISCV::VFSUB_VV: 1005 case RISCV::VFRSUB_VF: 1006 // Vector Widening Floating-Point Add/Subtract Instructions 1007 case RISCV::VFWADD_VV: 1008 case RISCV::VFWADD_VF: 1009 case RISCV::VFWSUB_VV: 1010 case RISCV::VFWSUB_VF: 1011 case RISCV::VFWADD_WF: 1012 case RISCV::VFWADD_WV: 1013 case RISCV::VFWSUB_WF: 1014 case RISCV::VFWSUB_WV: 1015 // Vector Single-Width Floating-Point Multiply/Divide Instructions 1016 case RISCV::VFMUL_VF: 1017 case RISCV::VFMUL_VV: 1018 case RISCV::VFDIV_VF: 1019 case RISCV::VFDIV_VV: 1020 case RISCV::VFRDIV_VF: 1021 // Vector Widening Floating-Point Multiply 1022 case RISCV::VFWMUL_VF: 1023 case RISCV::VFWMUL_VV: 1024 // Vector Floating-Point MIN/MAX Instructions 1025 case RISCV::VFMIN_VF: 1026 case RISCV::VFMIN_VV: 1027 case RISCV::VFMAX_VF: 1028 case RISCV::VFMAX_VV: 1029 // Vector Floating-Point Sign-Injection Instructions 1030 case RISCV::VFSGNJ_VF: 1031 case RISCV::VFSGNJ_VV: 1032 case RISCV::VFSGNJN_VV: 1033 case RISCV::VFSGNJN_VF: 1034 case RISCV::VFSGNJX_VF: 1035 case RISCV::VFSGNJX_VV: 1036 // Vector Floating-Point Compare Instructions 1037 case RISCV::VMFEQ_VF: 1038 case RISCV::VMFEQ_VV: 1039 case RISCV::VMFNE_VF: 1040 case RISCV::VMFNE_VV: 1041 case RISCV::VMFLT_VF: 1042 case RISCV::VMFLT_VV: 1043 case RISCV::VMFLE_VF: 1044 case RISCV::VMFLE_VV: 1045 case RISCV::VMFGT_VF: 1046 case RISCV::VMFGE_VF: 1047 // Single-Width Floating-Point/Integer Type-Convert Instructions 1048 case RISCV::VFCVT_XU_F_V: 1049 case RISCV::VFCVT_X_F_V: 1050 case RISCV::VFCVT_RTZ_XU_F_V: 1051 case RISCV::VFCVT_RTZ_X_F_V: 
case RISCV::VFCVT_F_XU_V:
  case RISCV::VFCVT_F_X_V:
  // Widening Floating-Point/Integer Type-Convert Instructions
  case RISCV::VFWCVT_XU_F_V:
  case RISCV::VFWCVT_X_F_V:
  case RISCV::VFWCVT_RTZ_XU_F_V:
  case RISCV::VFWCVT_RTZ_X_F_V:
  case RISCV::VFWCVT_F_XU_V:
  case RISCV::VFWCVT_F_X_V:
  case RISCV::VFWCVT_F_F_V:
  case RISCV::VFWCVTBF16_F_F_V:
  // Narrowing Floating-Point/Integer Type-Convert Instructions
  case RISCV::VFNCVT_XU_F_W:
  case RISCV::VFNCVT_X_F_W:
  case RISCV::VFNCVT_RTZ_XU_F_W:
  case RISCV::VFNCVT_RTZ_X_F_W:
  case RISCV::VFNCVT_F_XU_W:
  case RISCV::VFNCVT_F_X_W:
  case RISCV::VFNCVT_F_F_W:
  case RISCV::VFNCVT_ROD_F_F_W:
  case RISCV::VFNCVTBF16_F_F_W:
    return true;
  }

  // Anything not on the allow-list above is treated as unsupported.
  return false;
}

/// Return true if MO is a vector operand but is used as a scalar operand.
///
/// Only RVV pseudos (instructions with a RISCVVPseudosTable entry) are
/// recognized; for any other instruction this returns false. The recognized
/// scalar uses are:
///  - operand 3 of the integer/floating-point (widening) reductions listed
///    below, which the original comment identifies as vs1, of which only
///    element 0 is read;
///  - operand 1 of vmv.x.s / vfmv.f.s.
/// NOTE(review): the operand indices assume the pseudo operand layout
/// (defs, passthru, sources, ...) used by these pseudos -- confirm against
/// the pseudo definitions when adding new opcodes here.
static bool isVectorOpUsedAsScalarOp(MachineOperand &MO) {
  MachineInstr *MI = MO.getParent();
  const RISCVVPseudosTable::PseudoInfo *RVV =
      RISCVVPseudosTable::getPseudoInfo(MI->getOpcode());

  // Not an RVV pseudo: no scalar-use knowledge is available.
  if (!RVV)
    return false;

  switch (RVV->BaseInstr) {
  // Reductions only use vs1[0] of vs1
  case RISCV::VREDAND_VS:
  case RISCV::VREDMAX_VS:
  case RISCV::VREDMAXU_VS:
  case RISCV::VREDMIN_VS:
  case RISCV::VREDMINU_VS:
  case RISCV::VREDOR_VS:
  case RISCV::VREDSUM_VS:
  case RISCV::VREDXOR_VS:
  case RISCV::VWREDSUM_VS:
  case RISCV::VWREDSUMU_VS:
  case RISCV::VFREDMAX_VS:
  case RISCV::VFREDMIN_VS:
  case RISCV::VFREDOSUM_VS:
  case RISCV::VFREDUSUM_VS:
  case RISCV::VFWREDOSUM_VS:
  case RISCV::VFWREDUSUM_VS:
    // Only the vs1 operand is a scalar use; the other vector source is not.
    return MO.getOperandNo() == 3;
  case RISCV::VMV_X_S:
  case RISCV::VFMV_F_S:
    // The single vector source of the vector-to-scalar moves.
    return MO.getOperandNo() == 1;
  default:
    return false;
  }
}

/// Return true if MI may read elements past VL.
1116 static bool mayReadPastVL(const MachineInstr &MI) { 1117 const RISCVVPseudosTable::PseudoInfo *RVV = 1118 RISCVVPseudosTable::getPseudoInfo(MI.getOpcode()); 1119 if (!RVV) 1120 return true; 1121 1122 switch (RVV->BaseInstr) { 1123 // vslidedown instructions may read elements past VL. They are handled 1124 // according to current tail policy. 1125 case RISCV::VSLIDEDOWN_VI: 1126 case RISCV::VSLIDEDOWN_VX: 1127 case RISCV::VSLIDE1DOWN_VX: 1128 case RISCV::VFSLIDE1DOWN_VF: 1129 1130 // vrgather instructions may read the source vector at any index < VLMAX, 1131 // regardless of VL. 1132 case RISCV::VRGATHER_VI: 1133 case RISCV::VRGATHER_VV: 1134 case RISCV::VRGATHER_VX: 1135 case RISCV::VRGATHEREI16_VV: 1136 return true; 1137 1138 default: 1139 return false; 1140 } 1141 } 1142 1143 bool RISCVVLOptimizer::isCandidate(const MachineInstr &MI) const { 1144 const MCInstrDesc &Desc = MI.getDesc(); 1145 if (!RISCVII::hasVLOp(Desc.TSFlags) || !RISCVII::hasSEWOp(Desc.TSFlags)) 1146 return false; 1147 if (MI.getNumDefs() != 1) 1148 return false; 1149 1150 if (MI.mayRaiseFPException()) { 1151 LLVM_DEBUG(dbgs() << "Not a candidate because may raise FP exception\n"); 1152 return false; 1153 } 1154 1155 // Some instructions that produce vectors have semantics that make it more 1156 // difficult to determine whether the VL can be reduced. For example, some 1157 // instructions, such as reductions, may write lanes past VL to a scalar 1158 // register. Other instructions, such as some loads or stores, may write 1159 // lower lanes using data from higher lanes. There may be other complex 1160 // semantics not mentioned here that make it hard to determine whether 1161 // the VL can be optimized. As a result, a white-list of supported 1162 // instructions is used. Over time, more instructions can be supported 1163 // upon careful examination of their semantics under the logic in this 1164 // optimization. 
1165 // TODO: Use a better approach than a white-list, such as adding 1166 // properties to instructions using something like TSFlags. 1167 if (!isSupportedInstr(MI)) { 1168 LLVM_DEBUG(dbgs() << "Not a candidate due to unsupported instruction\n"); 1169 return false; 1170 } 1171 1172 assert(MI.getOperand(0).isReg() && 1173 isVectorRegClass(MI.getOperand(0).getReg(), MRI) && 1174 "All supported instructions produce a vector register result"); 1175 1176 LLVM_DEBUG(dbgs() << "Found a candidate for VL reduction: " << MI << "\n"); 1177 return true; 1178 } 1179 1180 std::optional<MachineOperand> 1181 RISCVVLOptimizer::getMinimumVLForUser(MachineOperand &UserOp) { 1182 const MachineInstr &UserMI = *UserOp.getParent(); 1183 const MCInstrDesc &Desc = UserMI.getDesc(); 1184 1185 if (!RISCVII::hasVLOp(Desc.TSFlags) || !RISCVII::hasSEWOp(Desc.TSFlags)) { 1186 LLVM_DEBUG(dbgs() << " Abort due to lack of VL, assume that" 1187 " use VLMAX\n"); 1188 return std::nullopt; 1189 } 1190 1191 // Instructions like reductions may use a vector register as a scalar 1192 // register. In this case, we should treat it as only reading the first lane. 1193 if (isVectorOpUsedAsScalarOp(UserOp)) { 1194 [[maybe_unused]] Register R = UserOp.getReg(); 1195 [[maybe_unused]] const TargetRegisterClass *RC = MRI->getRegClass(R); 1196 assert(RISCV::VRRegClass.hasSubClassEq(RC) && 1197 "Expect LMUL 1 register class for vector as scalar operands!"); 1198 LLVM_DEBUG(dbgs() << " Used this operand as a scalar operand\n"); 1199 1200 return MachineOperand::CreateImm(1); 1201 } 1202 1203 unsigned VLOpNum = RISCVII::getVLOpNum(Desc); 1204 const MachineOperand &VLOp = UserMI.getOperand(VLOpNum); 1205 // Looking for an immediate or a register VL that isn't X0. 1206 assert((!VLOp.isReg() || VLOp.getReg() != RISCV::X0) && 1207 "Did not expect X0 VL"); 1208 1209 // If we know the demanded VL of UserMI, then we can reduce the VL it 1210 // requires. 
1211 if (auto DemandedVL = DemandedVLs[&UserMI]) { 1212 assert(isCandidate(UserMI)); 1213 if (RISCV::isVLKnownLE(*DemandedVL, VLOp)) 1214 return DemandedVL; 1215 } 1216 1217 return VLOp; 1218 } 1219 1220 std::optional<MachineOperand> RISCVVLOptimizer::checkUsers(MachineInstr &MI) { 1221 std::optional<MachineOperand> CommonVL; 1222 for (auto &UserOp : MRI->use_operands(MI.getOperand(0).getReg())) { 1223 const MachineInstr &UserMI = *UserOp.getParent(); 1224 LLVM_DEBUG(dbgs() << " Checking user: " << UserMI << "\n"); 1225 if (mayReadPastVL(UserMI)) { 1226 LLVM_DEBUG(dbgs() << " Abort because used by unsafe instruction\n"); 1227 return std::nullopt; 1228 } 1229 1230 // If used as a passthru, elements past VL will be read. 1231 if (UserOp.isTied()) { 1232 LLVM_DEBUG(dbgs() << " Abort because user used as tied operand\n"); 1233 return std::nullopt; 1234 } 1235 1236 auto VLOp = getMinimumVLForUser(UserOp); 1237 if (!VLOp) 1238 return std::nullopt; 1239 1240 // Use the largest VL among all the users. If we cannot determine this 1241 // statically, then we cannot optimize the VL. 
1242 if (!CommonVL || RISCV::isVLKnownLE(*CommonVL, *VLOp)) { 1243 CommonVL = *VLOp; 1244 LLVM_DEBUG(dbgs() << " User VL is: " << VLOp << "\n"); 1245 } else if (!RISCV::isVLKnownLE(*VLOp, *CommonVL)) { 1246 LLVM_DEBUG(dbgs() << " Abort because cannot determine a common VL\n"); 1247 return std::nullopt; 1248 } 1249 1250 if (!RISCVII::hasSEWOp(UserMI.getDesc().TSFlags)) { 1251 LLVM_DEBUG(dbgs() << " Abort due to lack of SEW operand\n"); 1252 return std::nullopt; 1253 } 1254 1255 std::optional<OperandInfo> ConsumerInfo = getOperandInfo(UserOp, MRI); 1256 std::optional<OperandInfo> ProducerInfo = 1257 getOperandInfo(MI.getOperand(0), MRI); 1258 if (!ConsumerInfo || !ProducerInfo) { 1259 LLVM_DEBUG(dbgs() << " Abort due to unknown operand information.\n"); 1260 LLVM_DEBUG(dbgs() << " ConsumerInfo is: " << ConsumerInfo << "\n"); 1261 LLVM_DEBUG(dbgs() << " ProducerInfo is: " << ProducerInfo << "\n"); 1262 return std::nullopt; 1263 } 1264 1265 // If the operand is used as a scalar operand, then the EEW must be 1266 // compatible. Otherwise, the EMUL *and* EEW must be compatible. 
1267 bool IsVectorOpUsedAsScalarOp = isVectorOpUsedAsScalarOp(UserOp); 1268 if ((IsVectorOpUsedAsScalarOp && 1269 !OperandInfo::EEWAreEqual(*ConsumerInfo, *ProducerInfo)) || 1270 (!IsVectorOpUsedAsScalarOp && 1271 !OperandInfo::EMULAndEEWAreEqual(*ConsumerInfo, *ProducerInfo))) { 1272 LLVM_DEBUG( 1273 dbgs() 1274 << " Abort due to incompatible information for EMUL or EEW.\n"); 1275 LLVM_DEBUG(dbgs() << " ConsumerInfo is: " << ConsumerInfo << "\n"); 1276 LLVM_DEBUG(dbgs() << " ProducerInfo is: " << ProducerInfo << "\n"); 1277 return std::nullopt; 1278 } 1279 } 1280 1281 return CommonVL; 1282 } 1283 1284 bool RISCVVLOptimizer::tryReduceVL(MachineInstr &MI) { 1285 LLVM_DEBUG(dbgs() << "Trying to reduce VL for " << MI << "\n"); 1286 1287 unsigned VLOpNum = RISCVII::getVLOpNum(MI.getDesc()); 1288 MachineOperand &VLOp = MI.getOperand(VLOpNum); 1289 1290 // If the VL is 1, then there is no need to reduce it. This is an 1291 // optimization, not needed to preserve correctness. 1292 if (VLOp.isImm() && VLOp.getImm() == 1) { 1293 LLVM_DEBUG(dbgs() << " Abort due to VL == 1, no point in reducing.\n"); 1294 return false; 1295 } 1296 1297 auto CommonVL = DemandedVLs[&MI]; 1298 if (!CommonVL) 1299 return false; 1300 1301 assert((CommonVL->isImm() || CommonVL->getReg().isVirtual()) && 1302 "Expected VL to be an Imm or virtual Reg"); 1303 1304 if (!RISCV::isVLKnownLE(*CommonVL, VLOp)) { 1305 LLVM_DEBUG(dbgs() << " Abort due to CommonVL not <= VLOp.\n"); 1306 return false; 1307 } 1308 1309 if (CommonVL->isIdenticalTo(VLOp)) { 1310 LLVM_DEBUG( 1311 dbgs() << " Abort due to CommonVL == VLOp, no point in reducing.\n"); 1312 return false; 1313 } 1314 1315 if (CommonVL->isImm()) { 1316 LLVM_DEBUG(dbgs() << " Reduce VL from " << VLOp << " to " 1317 << CommonVL->getImm() << " for " << MI << "\n"); 1318 VLOp.ChangeToImmediate(CommonVL->getImm()); 1319 return true; 1320 } 1321 const MachineInstr *VLMI = MRI->getVRegDef(CommonVL->getReg()); 1322 if (!MDT->dominates(VLMI, &MI)) 1323 return 
false; 1324 LLVM_DEBUG( 1325 dbgs() << " Reduce VL from " << VLOp << " to " 1326 << printReg(CommonVL->getReg(), MRI->getTargetRegisterInfo()) 1327 << " for " << MI << "\n"); 1328 1329 // All our checks passed. We can reduce VL. 1330 VLOp.ChangeToRegister(CommonVL->getReg(), false); 1331 return true; 1332 } 1333 1334 bool RISCVVLOptimizer::runOnMachineFunction(MachineFunction &MF) { 1335 if (skipFunction(MF.getFunction())) 1336 return false; 1337 1338 MRI = &MF.getRegInfo(); 1339 MDT = &getAnalysis<MachineDominatorTreeWrapperPass>().getDomTree(); 1340 1341 const RISCVSubtarget &ST = MF.getSubtarget<RISCVSubtarget>(); 1342 if (!ST.hasVInstructions()) 1343 return false; 1344 1345 // For each instruction that defines a vector, compute what VL its 1346 // downstream users demand. 1347 for (MachineBasicBlock *MBB : post_order(&MF)) { 1348 assert(MDT->isReachableFromEntry(MBB)); 1349 for (MachineInstr &MI : reverse(*MBB)) { 1350 if (!isCandidate(MI)) 1351 continue; 1352 DemandedVLs.insert({&MI, checkUsers(MI)}); 1353 } 1354 } 1355 1356 // Then go through and see if we can reduce the VL of any instructions to 1357 // only what's demanded. 1358 bool MadeChange = false; 1359 for (MachineBasicBlock &MBB : MF) { 1360 // Avoid unreachable blocks as they have degenerate dominance 1361 if (!MDT->isReachableFromEntry(&MBB)) 1362 continue; 1363 1364 for (auto &MI : reverse(MBB)) { 1365 if (!isCandidate(MI)) 1366 continue; 1367 if (!tryReduceVL(MI)) 1368 continue; 1369 MadeChange = true; 1370 } 1371 } 1372 1373 return MadeChange; 1374 } 1375