xref: /llvm-project/llvm/lib/Target/RISCV/RISCVVLOptimizer.cpp (revision 8675cd3facc063673c47ed585bd4b7013119fe1f)
1 //===-------------- RISCVVLOptimizer.cpp - VL Optimizer -------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===---------------------------------------------------------------------===//
8 //
9 // This pass reduces the VL where possible at the MI level, before VSETVLI
10 // instructions are inserted.
11 //
12 // The purpose of this optimization is to make the VL argument, for instructions
13 // that have a VL argument, as small as possible. This is implemented by
14 // visiting each instruction in reverse order and checking that if it has a VL
15 // argument, whether the VL can be reduced.
16 //
17 //===---------------------------------------------------------------------===//
18 
19 #include "RISCV.h"
20 #include "RISCVSubtarget.h"
21 #include "llvm/ADT/PostOrderIterator.h"
22 #include "llvm/CodeGen/MachineDominators.h"
23 #include "llvm/CodeGen/MachineFunctionPass.h"
24 #include "llvm/InitializePasses.h"
25 
26 using namespace llvm;
27 
28 #define DEBUG_TYPE "riscv-vl-optimizer"
29 #define PASS_NAME "RISC-V VL Optimizer"
30 
31 namespace {
32 
33 class RISCVVLOptimizer : public MachineFunctionPass {
34   const MachineRegisterInfo *MRI;
35   const MachineDominatorTree *MDT;
36 
37 public:
38   static char ID;
39 
40   RISCVVLOptimizer() : MachineFunctionPass(ID) {}
41 
42   bool runOnMachineFunction(MachineFunction &MF) override;
43 
44   void getAnalysisUsage(AnalysisUsage &AU) const override {
45     AU.setPreservesCFG();
46     AU.addRequired<MachineDominatorTreeWrapperPass>();
47     MachineFunctionPass::getAnalysisUsage(AU);
48   }
49 
50   StringRef getPassName() const override { return PASS_NAME; }
51 
52 private:
53   std::optional<MachineOperand> getMinimumVLForUser(MachineOperand &UserOp);
54   /// Returns the largest common VL MachineOperand that may be used to optimize
55   /// MI. Returns std::nullopt if it failed to find a suitable VL.
56   std::optional<MachineOperand> checkUsers(MachineInstr &MI);
57   bool tryReduceVL(MachineInstr &MI);
58   bool isCandidate(const MachineInstr &MI) const;
59 
60   /// For a given instruction, records what elements of it are demanded by
61   /// downstream users.
62   DenseMap<const MachineInstr *, std::optional<MachineOperand>> DemandedVLs;
63 };
64 
65 } // end anonymous namespace
66 
// Storage for the pass identification object that the RISCVVLOptimizer
// constructor hands to MachineFunctionPass.
char RISCVVLOptimizer::ID = 0;
// Pass registration: the pass is registered under DEBUG_TYPE / PASS_NAME and
// declares its dependency on the machine dominator tree analysis, mirroring
// the addRequired<> call in getAnalysisUsage().
INITIALIZE_PASS_BEGIN(RISCVVLOptimizer, DEBUG_TYPE, PASS_NAME, false, false)
INITIALIZE_PASS_DEPENDENCY(MachineDominatorTreeWrapperPass)
INITIALIZE_PASS_END(RISCVVLOptimizer, DEBUG_TYPE, PASS_NAME, false, false)
71 
/// Factory for the pass: returns a newly heap-allocated RISCVVLOptimizer.
/// Ownership passes to the caller (presumably the pass manager -- the call
/// site is not visible in this file).
FunctionPass *llvm::createRISCVVLOptimizerPass() {
  return new RISCVVLOptimizer();
}
75 
76 /// Return true if R is a physical or virtual vector register, false otherwise.
77 static bool isVectorRegClass(Register R, const MachineRegisterInfo *MRI) {
78   if (R.isPhysical())
79     return RISCV::VRRegClass.contains(R);
80   const TargetRegisterClass *RC = MRI->getRegClass(R);
81   return RISCVRI::isVRegClass(RC->TSFlags);
82 }
83 
84 /// Represents the EMUL and EEW of a MachineOperand.
85 struct OperandInfo {
86   // Represent as 1,2,4,8, ... and fractional indicator. This is because
87   // EMUL can take on values that don't map to RISCVII::VLMUL values exactly.
88   // For example, a mask operand can have an EMUL less than MF8.
89   std::optional<std::pair<unsigned, bool>> EMUL;
90 
91   unsigned Log2EEW;
92 
93   OperandInfo(RISCVII::VLMUL EMUL, unsigned Log2EEW)
94       : EMUL(RISCVVType::decodeVLMUL(EMUL)), Log2EEW(Log2EEW) {}
95 
96   OperandInfo(std::pair<unsigned, bool> EMUL, unsigned Log2EEW)
97       : EMUL(EMUL), Log2EEW(Log2EEW) {}
98 
99   OperandInfo(unsigned Log2EEW) : Log2EEW(Log2EEW) {}
100 
101   OperandInfo() = delete;
102 
103   static bool EMULAndEEWAreEqual(const OperandInfo &A, const OperandInfo &B) {
104     return A.Log2EEW == B.Log2EEW && A.EMUL->first == B.EMUL->first &&
105            A.EMUL->second == B.EMUL->second;
106   }
107 
108   static bool EEWAreEqual(const OperandInfo &A, const OperandInfo &B) {
109     return A.Log2EEW == B.Log2EEW;
110   }
111 
112   void print(raw_ostream &OS) const {
113     if (EMUL) {
114       OS << "EMUL: m";
115       if (EMUL->second)
116         OS << "f";
117       OS << EMUL->first;
118     } else
119       OS << "EMUL: unknown\n";
120     OS << ", EEW: " << (1 << Log2EEW);
121   }
122 };
123 
/// Stream an OperandInfo by delegating to OperandInfo::print. Marked
/// LLVM_ATTRIBUTE_UNUSED, presumably because it is only referenced from debug
/// output that can be compiled out.
LLVM_ATTRIBUTE_UNUSED
static raw_ostream &operator<<(raw_ostream &OS, const OperandInfo &OI) {
  OI.print(OS);
  return OS;
}
129 
130 LLVM_ATTRIBUTE_UNUSED
131 static raw_ostream &operator<<(raw_ostream &OS,
132                                const std::optional<OperandInfo> &OI) {
133   if (OI)
134     OI->print(OS);
135   else
136     OS << "nullopt";
137   return OS;
138 }
139 
140 namespace llvm {
141 namespace RISCVVType {
142 /// Return EMUL = (EEW / SEW) * LMUL where EEW comes from Log2EEW and LMUL and
143 /// SEW are from the TSFlags of MI.
144 static std::pair<unsigned, bool>
145 getEMULEqualsEEWDivSEWTimesLMUL(unsigned Log2EEW, const MachineInstr &MI) {
146   RISCVII::VLMUL MIVLMUL = RISCVII::getLMul(MI.getDesc().TSFlags);
147   auto [MILMUL, MILMULIsFractional] = RISCVVType::decodeVLMUL(MIVLMUL);
148   unsigned MILog2SEW =
149       MI.getOperand(RISCVII::getSEWOpNum(MI.getDesc())).getImm();
150 
151   // Mask instructions will have 0 as the SEW operand. But the LMUL of these
152   // instructions is calculated is as if the SEW operand was 3 (e8).
153   if (MILog2SEW == 0)
154     MILog2SEW = 3;
155 
156   unsigned MISEW = 1 << MILog2SEW;
157 
158   unsigned EEW = 1 << Log2EEW;
159   // Calculate (EEW/SEW)*LMUL preserving fractions less than 1. Use GCD
160   // to put fraction in simplest form.
161   unsigned Num = EEW, Denom = MISEW;
162   int GCD = MILMULIsFractional ? std::gcd(Num, Denom * MILMUL)
163                                : std::gcd(Num * MILMUL, Denom);
164   Num = MILMULIsFractional ? Num / GCD : Num * MILMUL / GCD;
165   Denom = MILMULIsFractional ? Denom * MILMUL / GCD : Denom / GCD;
166   return std::make_pair(Num > Denom ? Num : Denom, Denom > Num);
167 }
168 } // end namespace RISCVVType
169 } // end namespace llvm
170 
171 /// Dest has EEW=SEW. Source EEW=SEW/Factor (i.e. F2 => EEW/2).
172 /// SEW comes from TSFlags of MI.
173 static unsigned getIntegerExtensionOperandEEW(unsigned Factor,
174                                               const MachineInstr &MI,
175                                               const MachineOperand &MO) {
176   unsigned MILog2SEW =
177       MI.getOperand(RISCVII::getSEWOpNum(MI.getDesc())).getImm();
178 
179   if (MO.getOperandNo() == 0)
180     return MILog2SEW;
181 
182   unsigned MISEW = 1 << MILog2SEW;
183   unsigned EEW = MISEW / Factor;
184   unsigned Log2EEW = Log2_32(EEW);
185 
186   return Log2EEW;
187 }
188 
189 /// Check whether MO is a mask operand of MI.
190 static bool isMaskOperand(const MachineInstr &MI, const MachineOperand &MO,
191                           const MachineRegisterInfo *MRI) {
192 
193   if (!MO.isReg() || !isVectorRegClass(MO.getReg(), MRI))
194     return false;
195 
196   const MCInstrDesc &Desc = MI.getDesc();
197   return Desc.operands()[MO.getOperandNo()].RegClass == RISCV::VMV0RegClassID;
198 }
199 
200 static std::optional<unsigned>
201 getOperandLog2EEW(const MachineOperand &MO, const MachineRegisterInfo *MRI) {
202   const MachineInstr &MI = *MO.getParent();
203   const RISCVVPseudosTable::PseudoInfo *RVV =
204       RISCVVPseudosTable::getPseudoInfo(MI.getOpcode());
205   assert(RVV && "Could not find MI in PseudoTable");
206 
207   // MI has a SEW associated with it. The RVV specification defines
208   // the EEW of each operand and definition in relation to MI.SEW.
209   unsigned MILog2SEW =
210       MI.getOperand(RISCVII::getSEWOpNum(MI.getDesc())).getImm();
211 
212   const bool HasPassthru = RISCVII::isFirstDefTiedToFirstUse(MI.getDesc());
213   const bool IsTied = RISCVII::isTiedPseudo(MI.getDesc().TSFlags);
214 
215   bool IsMODef = MO.getOperandNo() == 0;
216 
217   // All mask operands have EEW=1
218   if (isMaskOperand(MI, MO, MRI))
219     return 0;
220 
221   // switch against BaseInstr to reduce number of cases that need to be
222   // considered.
223   switch (RVV->BaseInstr) {
224 
225   // 6. Configuration-Setting Instructions
226   // Configuration setting instructions do not read or write vector registers
227   case RISCV::VSETIVLI:
228   case RISCV::VSETVL:
229   case RISCV::VSETVLI:
230     llvm_unreachable("Configuration setting instructions do not read or write "
231                      "vector registers");
232 
233   // Vector Loads and Stores
234   // Vector Unit-Stride Instructions
235   // Vector Strided Instructions
236   /// Dest EEW encoded in the instruction
237   case RISCV::VLM_V:
238   case RISCV::VSM_V:
239     return 0;
240   case RISCV::VLE8_V:
241   case RISCV::VSE8_V:
242   case RISCV::VLSE8_V:
243   case RISCV::VSSE8_V:
244     return 3;
245   case RISCV::VLE16_V:
246   case RISCV::VSE16_V:
247   case RISCV::VLSE16_V:
248   case RISCV::VSSE16_V:
249     return 4;
250   case RISCV::VLE32_V:
251   case RISCV::VSE32_V:
252   case RISCV::VLSE32_V:
253   case RISCV::VSSE32_V:
254     return 5;
255   case RISCV::VLE64_V:
256   case RISCV::VSE64_V:
257   case RISCV::VLSE64_V:
258   case RISCV::VSSE64_V:
259     return 6;
260 
261   // Vector Indexed Instructions
262   // vs(o|u)xei<eew>.v
263   // Dest/Data (operand 0) EEW=SEW.  Source EEW=<eew>.
264   case RISCV::VLUXEI8_V:
265   case RISCV::VLOXEI8_V:
266   case RISCV::VSUXEI8_V:
267   case RISCV::VSOXEI8_V: {
268     if (MO.getOperandNo() == 0)
269       return MILog2SEW;
270     return 3;
271   }
272   case RISCV::VLUXEI16_V:
273   case RISCV::VLOXEI16_V:
274   case RISCV::VSUXEI16_V:
275   case RISCV::VSOXEI16_V: {
276     if (MO.getOperandNo() == 0)
277       return MILog2SEW;
278     return 4;
279   }
280   case RISCV::VLUXEI32_V:
281   case RISCV::VLOXEI32_V:
282   case RISCV::VSUXEI32_V:
283   case RISCV::VSOXEI32_V: {
284     if (MO.getOperandNo() == 0)
285       return MILog2SEW;
286     return 5;
287   }
288   case RISCV::VLUXEI64_V:
289   case RISCV::VLOXEI64_V:
290   case RISCV::VSUXEI64_V:
291   case RISCV::VSOXEI64_V: {
292     if (MO.getOperandNo() == 0)
293       return MILog2SEW;
294     return 6;
295   }
296 
297   // Vector Integer Arithmetic Instructions
298   // Vector Single-Width Integer Add and Subtract
299   case RISCV::VADD_VI:
300   case RISCV::VADD_VV:
301   case RISCV::VADD_VX:
302   case RISCV::VSUB_VV:
303   case RISCV::VSUB_VX:
304   case RISCV::VRSUB_VI:
305   case RISCV::VRSUB_VX:
306   // Vector Bitwise Logical Instructions
307   // Vector Single-Width Shift Instructions
308   // EEW=SEW.
309   case RISCV::VAND_VI:
310   case RISCV::VAND_VV:
311   case RISCV::VAND_VX:
312   case RISCV::VOR_VI:
313   case RISCV::VOR_VV:
314   case RISCV::VOR_VX:
315   case RISCV::VXOR_VI:
316   case RISCV::VXOR_VV:
317   case RISCV::VXOR_VX:
318   case RISCV::VSLL_VI:
319   case RISCV::VSLL_VV:
320   case RISCV::VSLL_VX:
321   case RISCV::VSRL_VI:
322   case RISCV::VSRL_VV:
323   case RISCV::VSRL_VX:
324   case RISCV::VSRA_VI:
325   case RISCV::VSRA_VV:
326   case RISCV::VSRA_VX:
327   // Vector Integer Min/Max Instructions
328   // EEW=SEW.
329   case RISCV::VMINU_VV:
330   case RISCV::VMINU_VX:
331   case RISCV::VMIN_VV:
332   case RISCV::VMIN_VX:
333   case RISCV::VMAXU_VV:
334   case RISCV::VMAXU_VX:
335   case RISCV::VMAX_VV:
336   case RISCV::VMAX_VX:
337   // Vector Single-Width Integer Multiply Instructions
338   // Source and Dest EEW=SEW.
339   case RISCV::VMUL_VV:
340   case RISCV::VMUL_VX:
341   case RISCV::VMULH_VV:
342   case RISCV::VMULH_VX:
343   case RISCV::VMULHU_VV:
344   case RISCV::VMULHU_VX:
345   case RISCV::VMULHSU_VV:
346   case RISCV::VMULHSU_VX:
347   // Vector Integer Divide Instructions
348   // EEW=SEW.
349   case RISCV::VDIVU_VV:
350   case RISCV::VDIVU_VX:
351   case RISCV::VDIV_VV:
352   case RISCV::VDIV_VX:
353   case RISCV::VREMU_VV:
354   case RISCV::VREMU_VX:
355   case RISCV::VREM_VV:
356   case RISCV::VREM_VX:
357   // Vector Single-Width Integer Multiply-Add Instructions
358   // EEW=SEW.
359   case RISCV::VMACC_VV:
360   case RISCV::VMACC_VX:
361   case RISCV::VNMSAC_VV:
362   case RISCV::VNMSAC_VX:
363   case RISCV::VMADD_VV:
364   case RISCV::VMADD_VX:
365   case RISCV::VNMSUB_VV:
366   case RISCV::VNMSUB_VX:
367   // Vector Integer Merge Instructions
368   // Vector Integer Add-with-Carry / Subtract-with-Borrow Instructions
369   // EEW=SEW, except the mask operand has EEW=1. Mask operand is handled
370   // before this switch.
371   case RISCV::VMERGE_VIM:
372   case RISCV::VMERGE_VVM:
373   case RISCV::VMERGE_VXM:
374   case RISCV::VADC_VIM:
375   case RISCV::VADC_VVM:
376   case RISCV::VADC_VXM:
377   case RISCV::VSBC_VVM:
378   case RISCV::VSBC_VXM:
379   // Vector Integer Move Instructions
380   // Vector Fixed-Point Arithmetic Instructions
381   // Vector Single-Width Saturating Add and Subtract
382   // Vector Single-Width Averaging Add and Subtract
383   // EEW=SEW.
384   case RISCV::VMV_V_I:
385   case RISCV::VMV_V_V:
386   case RISCV::VMV_V_X:
387   case RISCV::VSADDU_VI:
388   case RISCV::VSADDU_VV:
389   case RISCV::VSADDU_VX:
390   case RISCV::VSADD_VI:
391   case RISCV::VSADD_VV:
392   case RISCV::VSADD_VX:
393   case RISCV::VSSUBU_VV:
394   case RISCV::VSSUBU_VX:
395   case RISCV::VSSUB_VV:
396   case RISCV::VSSUB_VX:
397   case RISCV::VAADDU_VV:
398   case RISCV::VAADDU_VX:
399   case RISCV::VAADD_VV:
400   case RISCV::VAADD_VX:
401   case RISCV::VASUBU_VV:
402   case RISCV::VASUBU_VX:
403   case RISCV::VASUB_VV:
404   case RISCV::VASUB_VX:
405   // Vector Single-Width Fractional Multiply with Rounding and Saturation
406   // EEW=SEW. The instruction produces 2*SEW product internally but
407   // saturates to fit into SEW bits.
408   case RISCV::VSMUL_VV:
409   case RISCV::VSMUL_VX:
410   // Vector Single-Width Scaling Shift Instructions
411   // EEW=SEW.
412   case RISCV::VSSRL_VI:
413   case RISCV::VSSRL_VV:
414   case RISCV::VSSRL_VX:
415   case RISCV::VSSRA_VI:
416   case RISCV::VSSRA_VV:
417   case RISCV::VSSRA_VX:
418   // Vector Permutation Instructions
419   // Integer Scalar Move Instructions
420   // Floating-Point Scalar Move Instructions
421   // EEW=SEW.
422   case RISCV::VMV_X_S:
423   case RISCV::VMV_S_X:
424   case RISCV::VFMV_F_S:
425   case RISCV::VFMV_S_F:
426   // Vector Slide Instructions
427   // EEW=SEW.
428   case RISCV::VSLIDEUP_VI:
429   case RISCV::VSLIDEUP_VX:
430   case RISCV::VSLIDEDOWN_VI:
431   case RISCV::VSLIDEDOWN_VX:
432   case RISCV::VSLIDE1UP_VX:
433   case RISCV::VFSLIDE1UP_VF:
434   case RISCV::VSLIDE1DOWN_VX:
435   case RISCV::VFSLIDE1DOWN_VF:
436   // Vector Register Gather Instructions
437   // EEW=SEW. For mask operand, EEW=1.
438   case RISCV::VRGATHER_VI:
439   case RISCV::VRGATHER_VV:
440   case RISCV::VRGATHER_VX:
441   // Vector Compress Instruction
442   // EEW=SEW.
443   case RISCV::VCOMPRESS_VM:
444   // Vector Element Index Instruction
445   case RISCV::VID_V:
446   // Vector Single-Width Floating-Point Add/Subtract Instructions
447   case RISCV::VFADD_VF:
448   case RISCV::VFADD_VV:
449   case RISCV::VFSUB_VF:
450   case RISCV::VFSUB_VV:
451   case RISCV::VFRSUB_VF:
452   // Vector Single-Width Floating-Point Multiply/Divide Instructions
453   case RISCV::VFMUL_VF:
454   case RISCV::VFMUL_VV:
455   case RISCV::VFDIV_VF:
456   case RISCV::VFDIV_VV:
457   case RISCV::VFRDIV_VF:
458   // Vector Floating-Point Square-Root Instruction
459   case RISCV::VFSQRT_V:
460   // Vector Floating-Point Reciprocal Square-Root Estimate Instruction
461   case RISCV::VFRSQRT7_V:
462   // Vector Floating-Point Reciprocal Estimate Instruction
463   case RISCV::VFREC7_V:
464   // Vector Floating-Point MIN/MAX Instructions
465   case RISCV::VFMIN_VF:
466   case RISCV::VFMIN_VV:
467   case RISCV::VFMAX_VF:
468   case RISCV::VFMAX_VV:
469   // Vector Floating-Point Sign-Injection Instructions
470   case RISCV::VFSGNJ_VF:
471   case RISCV::VFSGNJ_VV:
472   case RISCV::VFSGNJN_VV:
473   case RISCV::VFSGNJN_VF:
474   case RISCV::VFSGNJX_VF:
475   case RISCV::VFSGNJX_VV:
476   // Vector Floating-Point Classify Instruction
477   case RISCV::VFCLASS_V:
478   // Vector Floating-Point Move Instruction
479   case RISCV::VFMV_V_F:
480   // Single-Width Floating-Point/Integer Type-Convert Instructions
481   case RISCV::VFCVT_XU_F_V:
482   case RISCV::VFCVT_X_F_V:
483   case RISCV::VFCVT_RTZ_XU_F_V:
484   case RISCV::VFCVT_RTZ_X_F_V:
485   case RISCV::VFCVT_F_XU_V:
486   case RISCV::VFCVT_F_X_V:
487   // Vector Floating-Point Merge Instruction
488   case RISCV::VFMERGE_VFM:
489   // Vector count population in mask vcpop.m
490   // vfirst find-first-set mask bit
491   case RISCV::VCPOP_M:
492   case RISCV::VFIRST_M:
493     return MILog2SEW;
494 
495   // Vector Widening Integer Add/Subtract
496   // Def uses EEW=2*SEW . Operands use EEW=SEW.
497   case RISCV::VWADDU_VV:
498   case RISCV::VWADDU_VX:
499   case RISCV::VWSUBU_VV:
500   case RISCV::VWSUBU_VX:
501   case RISCV::VWADD_VV:
502   case RISCV::VWADD_VX:
503   case RISCV::VWSUB_VV:
504   case RISCV::VWSUB_VX:
505   case RISCV::VWSLL_VI:
506   // Vector Widening Integer Multiply Instructions
507   // Destination EEW=2*SEW. Source EEW=SEW.
508   case RISCV::VWMUL_VV:
509   case RISCV::VWMUL_VX:
510   case RISCV::VWMULSU_VV:
511   case RISCV::VWMULSU_VX:
512   case RISCV::VWMULU_VV:
513   case RISCV::VWMULU_VX:
514   // Vector Widening Integer Multiply-Add Instructions
515   // Destination EEW=2*SEW. Source EEW=SEW.
516   // A SEW-bit*SEW-bit multiply of the sources forms a 2*SEW-bit value, which
517   // is then added to the 2*SEW-bit Dest. These instructions never have a
518   // passthru operand.
519   case RISCV::VWMACCU_VV:
520   case RISCV::VWMACCU_VX:
521   case RISCV::VWMACC_VV:
522   case RISCV::VWMACC_VX:
523   case RISCV::VWMACCSU_VV:
524   case RISCV::VWMACCSU_VX:
525   case RISCV::VWMACCUS_VX:
526   // Vector Widening Floating-Point Fused Multiply-Add Instructions
527   case RISCV::VFWMACC_VF:
528   case RISCV::VFWMACC_VV:
529   case RISCV::VFWNMACC_VF:
530   case RISCV::VFWNMACC_VV:
531   case RISCV::VFWMSAC_VF:
532   case RISCV::VFWMSAC_VV:
533   case RISCV::VFWNMSAC_VF:
534   case RISCV::VFWNMSAC_VV:
535   // Vector Widening Floating-Point Add/Subtract Instructions
536   // Dest EEW=2*SEW. Source EEW=SEW.
537   case RISCV::VFWADD_VV:
538   case RISCV::VFWADD_VF:
539   case RISCV::VFWSUB_VV:
540   case RISCV::VFWSUB_VF:
541   // Vector Widening Floating-Point Multiply
542   case RISCV::VFWMUL_VF:
543   case RISCV::VFWMUL_VV:
544   // Widening Floating-Point/Integer Type-Convert Instructions
545   case RISCV::VFWCVT_XU_F_V:
546   case RISCV::VFWCVT_X_F_V:
547   case RISCV::VFWCVT_RTZ_XU_F_V:
548   case RISCV::VFWCVT_RTZ_X_F_V:
549   case RISCV::VFWCVT_F_XU_V:
550   case RISCV::VFWCVT_F_X_V:
551   case RISCV::VFWCVT_F_F_V:
552   case RISCV::VFWCVTBF16_F_F_V:
553     return IsMODef ? MILog2SEW + 1 : MILog2SEW;
554 
555   // Def and Op1 uses EEW=2*SEW. Op2 uses EEW=SEW.
556   case RISCV::VWADDU_WV:
557   case RISCV::VWADDU_WX:
558   case RISCV::VWSUBU_WV:
559   case RISCV::VWSUBU_WX:
560   case RISCV::VWADD_WV:
561   case RISCV::VWADD_WX:
562   case RISCV::VWSUB_WV:
563   case RISCV::VWSUB_WX:
564   // Vector Widening Floating-Point Add/Subtract Instructions
565   case RISCV::VFWADD_WF:
566   case RISCV::VFWADD_WV:
567   case RISCV::VFWSUB_WF:
568   case RISCV::VFWSUB_WV: {
569     bool IsOp1 = (HasPassthru && !IsTied) ? MO.getOperandNo() == 2
570                                           : MO.getOperandNo() == 1;
571     bool TwoTimes = IsMODef || IsOp1;
572     return TwoTimes ? MILog2SEW + 1 : MILog2SEW;
573   }
574 
575   // Vector Integer Extension
576   case RISCV::VZEXT_VF2:
577   case RISCV::VSEXT_VF2:
578     return getIntegerExtensionOperandEEW(2, MI, MO);
579   case RISCV::VZEXT_VF4:
580   case RISCV::VSEXT_VF4:
581     return getIntegerExtensionOperandEEW(4, MI, MO);
582   case RISCV::VZEXT_VF8:
583   case RISCV::VSEXT_VF8:
584     return getIntegerExtensionOperandEEW(8, MI, MO);
585 
586   // Vector Narrowing Integer Right Shift Instructions
587   // Destination EEW=SEW, Op 1 has EEW=2*SEW. Op2 has EEW=SEW
588   case RISCV::VNSRL_WX:
589   case RISCV::VNSRL_WI:
590   case RISCV::VNSRL_WV:
591   case RISCV::VNSRA_WI:
592   case RISCV::VNSRA_WV:
593   case RISCV::VNSRA_WX:
594   // Vector Narrowing Fixed-Point Clip Instructions
595   // Destination and Op1 EEW=SEW. Op2 EEW=2*SEW.
596   case RISCV::VNCLIPU_WI:
597   case RISCV::VNCLIPU_WV:
598   case RISCV::VNCLIPU_WX:
599   case RISCV::VNCLIP_WI:
600   case RISCV::VNCLIP_WV:
601   case RISCV::VNCLIP_WX:
602   // Narrowing Floating-Point/Integer Type-Convert Instructions
603   case RISCV::VFNCVT_XU_F_W:
604   case RISCV::VFNCVT_X_F_W:
605   case RISCV::VFNCVT_RTZ_XU_F_W:
606   case RISCV::VFNCVT_RTZ_X_F_W:
607   case RISCV::VFNCVT_F_XU_W:
608   case RISCV::VFNCVT_F_X_W:
609   case RISCV::VFNCVT_F_F_W:
610   case RISCV::VFNCVT_ROD_F_F_W:
611   case RISCV::VFNCVTBF16_F_F_W: {
612     assert(!IsTied);
613     bool IsOp1 = HasPassthru ? MO.getOperandNo() == 2 : MO.getOperandNo() == 1;
614     bool TwoTimes = IsOp1;
615     return TwoTimes ? MILog2SEW + 1 : MILog2SEW;
616   }
617 
618   // Vector Mask Instructions
619   // Vector Mask-Register Logical Instructions
620   // vmsbf.m set-before-first mask bit
621   // vmsif.m set-including-first mask bit
622   // vmsof.m set-only-first mask bit
623   // EEW=1
624   // We handle the cases when operand is a v0 mask operand above the switch,
625   // but these instructions may use non-v0 mask operands and need to be handled
626   // specifically.
627   case RISCV::VMAND_MM:
628   case RISCV::VMNAND_MM:
629   case RISCV::VMANDN_MM:
630   case RISCV::VMXOR_MM:
631   case RISCV::VMOR_MM:
632   case RISCV::VMNOR_MM:
633   case RISCV::VMORN_MM:
634   case RISCV::VMXNOR_MM:
635   case RISCV::VMSBF_M:
636   case RISCV::VMSIF_M:
637   case RISCV::VMSOF_M: {
638     return MILog2SEW;
639   }
640 
641   // Vector Iota Instruction
642   // EEW=SEW, except the mask operand has EEW=1. Mask operand is not handled
643   // before this switch.
644   case RISCV::VIOTA_M: {
645     if (IsMODef || MO.getOperandNo() == 1)
646       return MILog2SEW;
647     return 0;
648   }
649 
650   // Vector Integer Compare Instructions
651   // Dest EEW=1. Source EEW=SEW.
652   case RISCV::VMSEQ_VI:
653   case RISCV::VMSEQ_VV:
654   case RISCV::VMSEQ_VX:
655   case RISCV::VMSNE_VI:
656   case RISCV::VMSNE_VV:
657   case RISCV::VMSNE_VX:
658   case RISCV::VMSLTU_VV:
659   case RISCV::VMSLTU_VX:
660   case RISCV::VMSLT_VV:
661   case RISCV::VMSLT_VX:
662   case RISCV::VMSLEU_VV:
663   case RISCV::VMSLEU_VI:
664   case RISCV::VMSLEU_VX:
665   case RISCV::VMSLE_VV:
666   case RISCV::VMSLE_VI:
667   case RISCV::VMSLE_VX:
668   case RISCV::VMSGTU_VI:
669   case RISCV::VMSGTU_VX:
670   case RISCV::VMSGT_VI:
671   case RISCV::VMSGT_VX:
672   // Vector Integer Add-with-Carry / Subtract-with-Borrow Instructions
673   // Dest EEW=1. Source EEW=SEW. Mask source operand handled above this switch.
674   case RISCV::VMADC_VIM:
675   case RISCV::VMADC_VVM:
676   case RISCV::VMADC_VXM:
677   case RISCV::VMSBC_VVM:
678   case RISCV::VMSBC_VXM:
679   // Dest EEW=1. Source EEW=SEW.
680   case RISCV::VMADC_VV:
681   case RISCV::VMADC_VI:
682   case RISCV::VMADC_VX:
683   case RISCV::VMSBC_VV:
684   case RISCV::VMSBC_VX:
685   // 13.13. Vector Floating-Point Compare Instructions
686   // Dest EEW=1. Source EEW=SEW
687   case RISCV::VMFEQ_VF:
688   case RISCV::VMFEQ_VV:
689   case RISCV::VMFNE_VF:
690   case RISCV::VMFNE_VV:
691   case RISCV::VMFLT_VF:
692   case RISCV::VMFLT_VV:
693   case RISCV::VMFLE_VF:
694   case RISCV::VMFLE_VV:
695   case RISCV::VMFGT_VF:
696   case RISCV::VMFGE_VF: {
697     if (IsMODef)
698       return 0;
699     return MILog2SEW;
700   }
701 
702   // Vector Reduction Operations
703   // Vector Single-Width Integer Reduction Instructions
704   case RISCV::VREDAND_VS:
705   case RISCV::VREDMAX_VS:
706   case RISCV::VREDMAXU_VS:
707   case RISCV::VREDMIN_VS:
708   case RISCV::VREDMINU_VS:
709   case RISCV::VREDOR_VS:
710   case RISCV::VREDSUM_VS:
711   case RISCV::VREDXOR_VS:
712   // Vector Single-Width Floating-Point Reduction Instructions
713   case RISCV::VFREDMAX_VS:
714   case RISCV::VFREDMIN_VS:
715   case RISCV::VFREDOSUM_VS:
716   case RISCV::VFREDUSUM_VS: {
717     return MILog2SEW;
718   }
719 
720   // Vector Widening Integer Reduction Instructions
721   // The Dest and VS1 read only element 0 for the vector register. Return
722   // 2*EEW for these. VS2 has EEW=SEW and EMUL=LMUL.
723   case RISCV::VWREDSUM_VS:
724   case RISCV::VWREDSUMU_VS:
725   // Vector Widening Floating-Point Reduction Instructions
726   case RISCV::VFWREDOSUM_VS:
727   case RISCV::VFWREDUSUM_VS: {
728     bool TwoTimes = IsMODef || MO.getOperandNo() == 3;
729     return TwoTimes ? MILog2SEW + 1 : MILog2SEW;
730   }
731 
732   default:
733     return std::nullopt;
734   }
735 }
736 
737 static std::optional<OperandInfo>
738 getOperandInfo(const MachineOperand &MO, const MachineRegisterInfo *MRI) {
739   const MachineInstr &MI = *MO.getParent();
740   const RISCVVPseudosTable::PseudoInfo *RVV =
741       RISCVVPseudosTable::getPseudoInfo(MI.getOpcode());
742   assert(RVV && "Could not find MI in PseudoTable");
743 
744   std::optional<unsigned> Log2EEW = getOperandLog2EEW(MO, MRI);
745   if (!Log2EEW)
746     return std::nullopt;
747 
748   switch (RVV->BaseInstr) {
749   // Vector Reduction Operations
750   // Vector Single-Width Integer Reduction Instructions
751   // Vector Widening Integer Reduction Instructions
752   // Vector Widening Floating-Point Reduction Instructions
753   // The Dest and VS1 only read element 0 of the vector register. Return just
754   // the EEW for these.
755   case RISCV::VREDAND_VS:
756   case RISCV::VREDMAX_VS:
757   case RISCV::VREDMAXU_VS:
758   case RISCV::VREDMIN_VS:
759   case RISCV::VREDMINU_VS:
760   case RISCV::VREDOR_VS:
761   case RISCV::VREDSUM_VS:
762   case RISCV::VREDXOR_VS:
763   case RISCV::VWREDSUM_VS:
764   case RISCV::VWREDSUMU_VS:
765   case RISCV::VFWREDOSUM_VS:
766   case RISCV::VFWREDUSUM_VS:
767     if (MO.getOperandNo() != 2)
768       return OperandInfo(*Log2EEW);
769     break;
770   };
771 
772   // All others have EMUL=EEW/SEW*LMUL
773   return OperandInfo(RISCVVType::getEMULEqualsEEWDivSEWTimesLMUL(*Log2EEW, MI),
774                      *Log2EEW);
775 }
776 
777 /// Return true if this optimization should consider MI for VL reduction. This
778 /// white-list approach simplifies this optimization for instructions that may
779 /// have more complex semantics with relation to how it uses VL.
780 static bool isSupportedInstr(const MachineInstr &MI) {
781   const RISCVVPseudosTable::PseudoInfo *RVV =
782       RISCVVPseudosTable::getPseudoInfo(MI.getOpcode());
783 
784   if (!RVV)
785     return false;
786 
787   switch (RVV->BaseInstr) {
788   // Vector Unit-Stride Instructions
789   // Vector Strided Instructions
790   case RISCV::VLM_V:
791   case RISCV::VLE8_V:
792   case RISCV::VLSE8_V:
793   case RISCV::VLE16_V:
794   case RISCV::VLSE16_V:
795   case RISCV::VLE32_V:
796   case RISCV::VLSE32_V:
797   case RISCV::VLE64_V:
798   case RISCV::VLSE64_V:
799   // Vector Indexed Instructions
800   case RISCV::VLUXEI8_V:
801   case RISCV::VLOXEI8_V:
802   case RISCV::VLUXEI16_V:
803   case RISCV::VLOXEI16_V:
804   case RISCV::VLUXEI32_V:
805   case RISCV::VLOXEI32_V:
806   case RISCV::VLUXEI64_V:
807   case RISCV::VLOXEI64_V: {
808     for (const MachineMemOperand *MMO : MI.memoperands())
809       if (MMO->isVolatile())
810         return false;
811     return true;
812   }
813 
814   // Vector Single-Width Integer Add and Subtract
815   case RISCV::VADD_VI:
816   case RISCV::VADD_VV:
817   case RISCV::VADD_VX:
818   case RISCV::VSUB_VV:
819   case RISCV::VSUB_VX:
820   case RISCV::VRSUB_VI:
821   case RISCV::VRSUB_VX:
822   // Vector Bitwise Logical Instructions
823   // Vector Single-Width Shift Instructions
824   case RISCV::VAND_VI:
825   case RISCV::VAND_VV:
826   case RISCV::VAND_VX:
827   case RISCV::VOR_VI:
828   case RISCV::VOR_VV:
829   case RISCV::VOR_VX:
830   case RISCV::VXOR_VI:
831   case RISCV::VXOR_VV:
832   case RISCV::VXOR_VX:
833   case RISCV::VSLL_VI:
834   case RISCV::VSLL_VV:
835   case RISCV::VSLL_VX:
836   case RISCV::VSRL_VI:
837   case RISCV::VSRL_VV:
838   case RISCV::VSRL_VX:
839   case RISCV::VSRA_VI:
840   case RISCV::VSRA_VV:
841   case RISCV::VSRA_VX:
842   // Vector Widening Integer Add/Subtract
843   case RISCV::VWADDU_VV:
844   case RISCV::VWADDU_VX:
845   case RISCV::VWSUBU_VV:
846   case RISCV::VWSUBU_VX:
847   case RISCV::VWADD_VV:
848   case RISCV::VWADD_VX:
849   case RISCV::VWSUB_VV:
850   case RISCV::VWSUB_VX:
851   case RISCV::VWADDU_WV:
852   case RISCV::VWADDU_WX:
853   case RISCV::VWSUBU_WV:
854   case RISCV::VWSUBU_WX:
855   case RISCV::VWADD_WV:
856   case RISCV::VWADD_WX:
857   case RISCV::VWSUB_WV:
858   case RISCV::VWSUB_WX:
859   // Vector Integer Extension
860   case RISCV::VZEXT_VF2:
861   case RISCV::VSEXT_VF2:
862   case RISCV::VZEXT_VF4:
863   case RISCV::VSEXT_VF4:
864   case RISCV::VZEXT_VF8:
865   case RISCV::VSEXT_VF8:
866   // Vector Integer Add-with-Carry / Subtract-with-Borrow Instructions
867   // FIXME: Add support
868   case RISCV::VMADC_VV:
869   case RISCV::VMADC_VI:
870   case RISCV::VMADC_VX:
871   case RISCV::VMSBC_VV:
872   case RISCV::VMSBC_VX:
873   // Vector Narrowing Integer Right Shift Instructions
874   case RISCV::VNSRL_WX:
875   case RISCV::VNSRL_WI:
876   case RISCV::VNSRL_WV:
877   case RISCV::VNSRA_WI:
878   case RISCV::VNSRA_WV:
879   case RISCV::VNSRA_WX:
880   // Vector Integer Compare Instructions
881   case RISCV::VMSEQ_VI:
882   case RISCV::VMSEQ_VV:
883   case RISCV::VMSEQ_VX:
884   case RISCV::VMSNE_VI:
885   case RISCV::VMSNE_VV:
886   case RISCV::VMSNE_VX:
887   case RISCV::VMSLTU_VV:
888   case RISCV::VMSLTU_VX:
889   case RISCV::VMSLT_VV:
890   case RISCV::VMSLT_VX:
891   case RISCV::VMSLEU_VV:
892   case RISCV::VMSLEU_VI:
893   case RISCV::VMSLEU_VX:
894   case RISCV::VMSLE_VV:
895   case RISCV::VMSLE_VI:
896   case RISCV::VMSLE_VX:
897   case RISCV::VMSGTU_VI:
898   case RISCV::VMSGTU_VX:
899   case RISCV::VMSGT_VI:
900   case RISCV::VMSGT_VX:
901   // Vector Integer Min/Max Instructions
902   case RISCV::VMINU_VV:
903   case RISCV::VMINU_VX:
904   case RISCV::VMIN_VV:
905   case RISCV::VMIN_VX:
906   case RISCV::VMAXU_VV:
907   case RISCV::VMAXU_VX:
908   case RISCV::VMAX_VV:
909   case RISCV::VMAX_VX:
910   // Vector Single-Width Integer Multiply Instructions
911   case RISCV::VMUL_VV:
912   case RISCV::VMUL_VX:
913   case RISCV::VMULH_VV:
914   case RISCV::VMULH_VX:
915   case RISCV::VMULHU_VV:
916   case RISCV::VMULHU_VX:
917   case RISCV::VMULHSU_VV:
918   case RISCV::VMULHSU_VX:
919   // Vector Integer Divide Instructions
920   case RISCV::VDIVU_VV:
921   case RISCV::VDIVU_VX:
922   case RISCV::VDIV_VV:
923   case RISCV::VDIV_VX:
924   case RISCV::VREMU_VV:
925   case RISCV::VREMU_VX:
926   case RISCV::VREM_VV:
927   case RISCV::VREM_VX:
928   // Vector Widening Integer Multiply Instructions
929   case RISCV::VWMUL_VV:
930   case RISCV::VWMUL_VX:
931   case RISCV::VWMULSU_VV:
932   case RISCV::VWMULSU_VX:
933   case RISCV::VWMULU_VV:
934   case RISCV::VWMULU_VX:
935   // Vector Single-Width Integer Multiply-Add Instructions
936   case RISCV::VMACC_VV:
937   case RISCV::VMACC_VX:
938   case RISCV::VNMSAC_VV:
939   case RISCV::VNMSAC_VX:
940   case RISCV::VMADD_VV:
941   case RISCV::VMADD_VX:
942   case RISCV::VNMSUB_VV:
943   case RISCV::VNMSUB_VX:
944   // Vector Integer Merge Instructions
945   case RISCV::VMERGE_VIM:
946   case RISCV::VMERGE_VVM:
947   case RISCV::VMERGE_VXM:
948   // Vector Integer Add-with-Carry / Subtract-with-Borrow Instructions
949   case RISCV::VADC_VIM:
950   case RISCV::VADC_VVM:
951   case RISCV::VADC_VXM:
952   // Vector Widening Integer Multiply-Add Instructions
953   case RISCV::VWMACCU_VV:
954   case RISCV::VWMACCU_VX:
955   case RISCV::VWMACC_VV:
956   case RISCV::VWMACC_VX:
957   case RISCV::VWMACCSU_VV:
958   case RISCV::VWMACCSU_VX:
959   case RISCV::VWMACCUS_VX:
960   // Vector Integer Merge Instructions
961   // FIXME: Add support
962   // Vector Integer Move Instructions
963   // FIXME: Add support
964   case RISCV::VMV_V_I:
965   case RISCV::VMV_V_X:
966   case RISCV::VMV_V_V:
967   // Vector Single-Width Averaging Add and Subtract
968   case RISCV::VAADDU_VV:
969   case RISCV::VAADDU_VX:
970   case RISCV::VAADD_VV:
971   case RISCV::VAADD_VX:
972   case RISCV::VASUBU_VV:
973   case RISCV::VASUBU_VX:
974   case RISCV::VASUB_VV:
975   case RISCV::VASUB_VX:
976 
977   // Vector Crypto
978   case RISCV::VWSLL_VI:
979 
980   // Vector Mask Instructions
981   // Vector Mask-Register Logical Instructions
982   // vmsbf.m set-before-first mask bit
983   // vmsif.m set-including-first mask bit
984   // vmsof.m set-only-first mask bit
985   // Vector Iota Instruction
986   // Vector Element Index Instruction
987   case RISCV::VMAND_MM:
988   case RISCV::VMNAND_MM:
989   case RISCV::VMANDN_MM:
990   case RISCV::VMXOR_MM:
991   case RISCV::VMOR_MM:
992   case RISCV::VMNOR_MM:
993   case RISCV::VMORN_MM:
994   case RISCV::VMXNOR_MM:
995   case RISCV::VMSBF_M:
996   case RISCV::VMSIF_M:
997   case RISCV::VMSOF_M:
998   case RISCV::VIOTA_M:
999   case RISCV::VID_V:
1000   // Vector Single-Width Floating-Point Add/Subtract Instructions
1001   case RISCV::VFADD_VF:
1002   case RISCV::VFADD_VV:
1003   case RISCV::VFSUB_VF:
1004   case RISCV::VFSUB_VV:
1005   case RISCV::VFRSUB_VF:
1006   // Vector Widening Floating-Point Add/Subtract Instructions
1007   case RISCV::VFWADD_VV:
1008   case RISCV::VFWADD_VF:
1009   case RISCV::VFWSUB_VV:
1010   case RISCV::VFWSUB_VF:
1011   case RISCV::VFWADD_WF:
1012   case RISCV::VFWADD_WV:
1013   case RISCV::VFWSUB_WF:
1014   case RISCV::VFWSUB_WV:
1015   // Vector Single-Width Floating-Point Multiply/Divide Instructions
1016   case RISCV::VFMUL_VF:
1017   case RISCV::VFMUL_VV:
1018   case RISCV::VFDIV_VF:
1019   case RISCV::VFDIV_VV:
1020   case RISCV::VFRDIV_VF:
1021   // Vector Widening Floating-Point Multiply
1022   case RISCV::VFWMUL_VF:
1023   case RISCV::VFWMUL_VV:
1024   // Vector Floating-Point MIN/MAX Instructions
1025   case RISCV::VFMIN_VF:
1026   case RISCV::VFMIN_VV:
1027   case RISCV::VFMAX_VF:
1028   case RISCV::VFMAX_VV:
1029   // Vector Floating-Point Sign-Injection Instructions
1030   case RISCV::VFSGNJ_VF:
1031   case RISCV::VFSGNJ_VV:
1032   case RISCV::VFSGNJN_VV:
1033   case RISCV::VFSGNJN_VF:
1034   case RISCV::VFSGNJX_VF:
1035   case RISCV::VFSGNJX_VV:
1036   // Vector Floating-Point Compare Instructions
1037   case RISCV::VMFEQ_VF:
1038   case RISCV::VMFEQ_VV:
1039   case RISCV::VMFNE_VF:
1040   case RISCV::VMFNE_VV:
1041   case RISCV::VMFLT_VF:
1042   case RISCV::VMFLT_VV:
1043   case RISCV::VMFLE_VF:
1044   case RISCV::VMFLE_VV:
1045   case RISCV::VMFGT_VF:
1046   case RISCV::VMFGE_VF:
1047   // Single-Width Floating-Point/Integer Type-Convert Instructions
1048   case RISCV::VFCVT_XU_F_V:
1049   case RISCV::VFCVT_X_F_V:
1050   case RISCV::VFCVT_RTZ_XU_F_V:
1051   case RISCV::VFCVT_RTZ_X_F_V:
1052   case RISCV::VFCVT_F_XU_V:
1053   case RISCV::VFCVT_F_X_V:
1054   // Widening Floating-Point/Integer Type-Convert Instructions
1055   case RISCV::VFWCVT_XU_F_V:
1056   case RISCV::VFWCVT_X_F_V:
1057   case RISCV::VFWCVT_RTZ_XU_F_V:
1058   case RISCV::VFWCVT_RTZ_X_F_V:
1059   case RISCV::VFWCVT_F_XU_V:
1060   case RISCV::VFWCVT_F_X_V:
1061   case RISCV::VFWCVT_F_F_V:
1062   case RISCV::VFWCVTBF16_F_F_V:
1063   // Narrowing Floating-Point/Integer Type-Convert Instructions
1064   case RISCV::VFNCVT_XU_F_W:
1065   case RISCV::VFNCVT_X_F_W:
1066   case RISCV::VFNCVT_RTZ_XU_F_W:
1067   case RISCV::VFNCVT_RTZ_X_F_W:
1068   case RISCV::VFNCVT_F_XU_W:
1069   case RISCV::VFNCVT_F_X_W:
1070   case RISCV::VFNCVT_F_F_W:
1071   case RISCV::VFNCVT_ROD_F_F_W:
1072   case RISCV::VFNCVTBF16_F_F_W:
1073     return true;
1074   }
1075 
1076   return false;
1077 }
1078 
1079 /// Return true if MO is a vector operand but is used as a scalar operand.
1080 static bool isVectorOpUsedAsScalarOp(MachineOperand &MO) {
1081   MachineInstr *MI = MO.getParent();
1082   const RISCVVPseudosTable::PseudoInfo *RVV =
1083       RISCVVPseudosTable::getPseudoInfo(MI->getOpcode());
1084 
1085   if (!RVV)
1086     return false;
1087 
1088   switch (RVV->BaseInstr) {
1089   // Reductions only use vs1[0] of vs1
1090   case RISCV::VREDAND_VS:
1091   case RISCV::VREDMAX_VS:
1092   case RISCV::VREDMAXU_VS:
1093   case RISCV::VREDMIN_VS:
1094   case RISCV::VREDMINU_VS:
1095   case RISCV::VREDOR_VS:
1096   case RISCV::VREDSUM_VS:
1097   case RISCV::VREDXOR_VS:
1098   case RISCV::VWREDSUM_VS:
1099   case RISCV::VWREDSUMU_VS:
1100   case RISCV::VFREDMAX_VS:
1101   case RISCV::VFREDMIN_VS:
1102   case RISCV::VFREDOSUM_VS:
1103   case RISCV::VFREDUSUM_VS:
1104   case RISCV::VFWREDOSUM_VS:
1105   case RISCV::VFWREDUSUM_VS:
1106     return MO.getOperandNo() == 3;
1107   case RISCV::VMV_X_S:
1108   case RISCV::VFMV_F_S:
1109     return MO.getOperandNo() == 1;
1110   default:
1111     return false;
1112   }
1113 }
1114 
1115 /// Return true if MI may read elements past VL.
1116 static bool mayReadPastVL(const MachineInstr &MI) {
1117   const RISCVVPseudosTable::PseudoInfo *RVV =
1118       RISCVVPseudosTable::getPseudoInfo(MI.getOpcode());
1119   if (!RVV)
1120     return true;
1121 
1122   switch (RVV->BaseInstr) {
1123   // vslidedown instructions may read elements past VL. They are handled
1124   // according to current tail policy.
1125   case RISCV::VSLIDEDOWN_VI:
1126   case RISCV::VSLIDEDOWN_VX:
1127   case RISCV::VSLIDE1DOWN_VX:
1128   case RISCV::VFSLIDE1DOWN_VF:
1129 
1130   // vrgather instructions may read the source vector at any index < VLMAX,
1131   // regardless of VL.
1132   case RISCV::VRGATHER_VI:
1133   case RISCV::VRGATHER_VV:
1134   case RISCV::VRGATHER_VX:
1135   case RISCV::VRGATHEREI16_VV:
1136     return true;
1137 
1138   default:
1139     return false;
1140   }
1141 }
1142 
1143 bool RISCVVLOptimizer::isCandidate(const MachineInstr &MI) const {
1144   const MCInstrDesc &Desc = MI.getDesc();
1145   if (!RISCVII::hasVLOp(Desc.TSFlags) || !RISCVII::hasSEWOp(Desc.TSFlags))
1146     return false;
1147   if (MI.getNumDefs() != 1)
1148     return false;
1149 
1150   if (MI.mayRaiseFPException()) {
1151     LLVM_DEBUG(dbgs() << "Not a candidate because may raise FP exception\n");
1152     return false;
1153   }
1154 
1155   // Some instructions that produce vectors have semantics that make it more
1156   // difficult to determine whether the VL can be reduced. For example, some
1157   // instructions, such as reductions, may write lanes past VL to a scalar
1158   // register. Other instructions, such as some loads or stores, may write
1159   // lower lanes using data from higher lanes. There may be other complex
1160   // semantics not mentioned here that make it hard to determine whether
1161   // the VL can be optimized. As a result, a white-list of supported
1162   // instructions is used. Over time, more instructions can be supported
1163   // upon careful examination of their semantics under the logic in this
1164   // optimization.
1165   // TODO: Use a better approach than a white-list, such as adding
1166   // properties to instructions using something like TSFlags.
1167   if (!isSupportedInstr(MI)) {
1168     LLVM_DEBUG(dbgs() << "Not a candidate due to unsupported instruction\n");
1169     return false;
1170   }
1171 
1172   assert(MI.getOperand(0).isReg() &&
1173          isVectorRegClass(MI.getOperand(0).getReg(), MRI) &&
1174          "All supported instructions produce a vector register result");
1175 
1176   LLVM_DEBUG(dbgs() << "Found a candidate for VL reduction: " << MI << "\n");
1177   return true;
1178 }
1179 
1180 std::optional<MachineOperand>
1181 RISCVVLOptimizer::getMinimumVLForUser(MachineOperand &UserOp) {
1182   const MachineInstr &UserMI = *UserOp.getParent();
1183   const MCInstrDesc &Desc = UserMI.getDesc();
1184 
1185   if (!RISCVII::hasVLOp(Desc.TSFlags) || !RISCVII::hasSEWOp(Desc.TSFlags)) {
1186     LLVM_DEBUG(dbgs() << "    Abort due to lack of VL, assume that"
1187                          " use VLMAX\n");
1188     return std::nullopt;
1189   }
1190 
1191   // Instructions like reductions may use a vector register as a scalar
1192   // register. In this case, we should treat it as only reading the first lane.
1193   if (isVectorOpUsedAsScalarOp(UserOp)) {
1194     [[maybe_unused]] Register R = UserOp.getReg();
1195     [[maybe_unused]] const TargetRegisterClass *RC = MRI->getRegClass(R);
1196     assert(RISCV::VRRegClass.hasSubClassEq(RC) &&
1197            "Expect LMUL 1 register class for vector as scalar operands!");
1198     LLVM_DEBUG(dbgs() << "    Used this operand as a scalar operand\n");
1199 
1200     return MachineOperand::CreateImm(1);
1201   }
1202 
1203   unsigned VLOpNum = RISCVII::getVLOpNum(Desc);
1204   const MachineOperand &VLOp = UserMI.getOperand(VLOpNum);
1205   // Looking for an immediate or a register VL that isn't X0.
1206   assert((!VLOp.isReg() || VLOp.getReg() != RISCV::X0) &&
1207          "Did not expect X0 VL");
1208 
1209   // If we know the demanded VL of UserMI, then we can reduce the VL it
1210   // requires.
1211   if (auto DemandedVL = DemandedVLs[&UserMI]) {
1212     assert(isCandidate(UserMI));
1213     if (RISCV::isVLKnownLE(*DemandedVL, VLOp))
1214       return DemandedVL;
1215   }
1216 
1217   return VLOp;
1218 }
1219 
1220 std::optional<MachineOperand> RISCVVLOptimizer::checkUsers(MachineInstr &MI) {
1221   std::optional<MachineOperand> CommonVL;
1222   for (auto &UserOp : MRI->use_operands(MI.getOperand(0).getReg())) {
1223     const MachineInstr &UserMI = *UserOp.getParent();
1224     LLVM_DEBUG(dbgs() << "  Checking user: " << UserMI << "\n");
1225     if (mayReadPastVL(UserMI)) {
1226       LLVM_DEBUG(dbgs() << "    Abort because used by unsafe instruction\n");
1227       return std::nullopt;
1228     }
1229 
1230     // If used as a passthru, elements past VL will be read.
1231     if (UserOp.isTied()) {
1232       LLVM_DEBUG(dbgs() << "    Abort because user used as tied operand\n");
1233       return std::nullopt;
1234     }
1235 
1236     auto VLOp = getMinimumVLForUser(UserOp);
1237     if (!VLOp)
1238       return std::nullopt;
1239 
1240     // Use the largest VL among all the users. If we cannot determine this
1241     // statically, then we cannot optimize the VL.
1242     if (!CommonVL || RISCV::isVLKnownLE(*CommonVL, *VLOp)) {
1243       CommonVL = *VLOp;
1244       LLVM_DEBUG(dbgs() << "    User VL is: " << VLOp << "\n");
1245     } else if (!RISCV::isVLKnownLE(*VLOp, *CommonVL)) {
1246       LLVM_DEBUG(dbgs() << "    Abort because cannot determine a common VL\n");
1247       return std::nullopt;
1248     }
1249 
1250     if (!RISCVII::hasSEWOp(UserMI.getDesc().TSFlags)) {
1251       LLVM_DEBUG(dbgs() << "    Abort due to lack of SEW operand\n");
1252       return std::nullopt;
1253     }
1254 
1255     std::optional<OperandInfo> ConsumerInfo = getOperandInfo(UserOp, MRI);
1256     std::optional<OperandInfo> ProducerInfo =
1257         getOperandInfo(MI.getOperand(0), MRI);
1258     if (!ConsumerInfo || !ProducerInfo) {
1259       LLVM_DEBUG(dbgs() << "    Abort due to unknown operand information.\n");
1260       LLVM_DEBUG(dbgs() << "      ConsumerInfo is: " << ConsumerInfo << "\n");
1261       LLVM_DEBUG(dbgs() << "      ProducerInfo is: " << ProducerInfo << "\n");
1262       return std::nullopt;
1263     }
1264 
1265     // If the operand is used as a scalar operand, then the EEW must be
1266     // compatible. Otherwise, the EMUL *and* EEW must be compatible.
1267     bool IsVectorOpUsedAsScalarOp = isVectorOpUsedAsScalarOp(UserOp);
1268     if ((IsVectorOpUsedAsScalarOp &&
1269          !OperandInfo::EEWAreEqual(*ConsumerInfo, *ProducerInfo)) ||
1270         (!IsVectorOpUsedAsScalarOp &&
1271          !OperandInfo::EMULAndEEWAreEqual(*ConsumerInfo, *ProducerInfo))) {
1272       LLVM_DEBUG(
1273           dbgs()
1274           << "    Abort due to incompatible information for EMUL or EEW.\n");
1275       LLVM_DEBUG(dbgs() << "      ConsumerInfo is: " << ConsumerInfo << "\n");
1276       LLVM_DEBUG(dbgs() << "      ProducerInfo is: " << ProducerInfo << "\n");
1277       return std::nullopt;
1278     }
1279   }
1280 
1281   return CommonVL;
1282 }
1283 
1284 bool RISCVVLOptimizer::tryReduceVL(MachineInstr &MI) {
1285   LLVM_DEBUG(dbgs() << "Trying to reduce VL for " << MI << "\n");
1286 
1287   unsigned VLOpNum = RISCVII::getVLOpNum(MI.getDesc());
1288   MachineOperand &VLOp = MI.getOperand(VLOpNum);
1289 
1290   // If the VL is 1, then there is no need to reduce it. This is an
1291   // optimization, not needed to preserve correctness.
1292   if (VLOp.isImm() && VLOp.getImm() == 1) {
1293     LLVM_DEBUG(dbgs() << "  Abort due to VL == 1, no point in reducing.\n");
1294     return false;
1295   }
1296 
1297   auto CommonVL = DemandedVLs[&MI];
1298   if (!CommonVL)
1299     return false;
1300 
1301   assert((CommonVL->isImm() || CommonVL->getReg().isVirtual()) &&
1302          "Expected VL to be an Imm or virtual Reg");
1303 
1304   if (!RISCV::isVLKnownLE(*CommonVL, VLOp)) {
1305     LLVM_DEBUG(dbgs() << "    Abort due to CommonVL not <= VLOp.\n");
1306     return false;
1307   }
1308 
1309   if (CommonVL->isIdenticalTo(VLOp)) {
1310     LLVM_DEBUG(
1311         dbgs() << "    Abort due to CommonVL == VLOp, no point in reducing.\n");
1312     return false;
1313   }
1314 
1315   if (CommonVL->isImm()) {
1316     LLVM_DEBUG(dbgs() << "  Reduce VL from " << VLOp << " to "
1317                       << CommonVL->getImm() << " for " << MI << "\n");
1318     VLOp.ChangeToImmediate(CommonVL->getImm());
1319     return true;
1320   }
1321   const MachineInstr *VLMI = MRI->getVRegDef(CommonVL->getReg());
1322   if (!MDT->dominates(VLMI, &MI))
1323     return false;
1324   LLVM_DEBUG(
1325       dbgs() << "  Reduce VL from " << VLOp << " to "
1326              << printReg(CommonVL->getReg(), MRI->getTargetRegisterInfo())
1327              << " for " << MI << "\n");
1328 
1329   // All our checks passed. We can reduce VL.
1330   VLOp.ChangeToRegister(CommonVL->getReg(), false);
1331   return true;
1332 }
1333 
1334 bool RISCVVLOptimizer::runOnMachineFunction(MachineFunction &MF) {
1335   if (skipFunction(MF.getFunction()))
1336     return false;
1337 
1338   MRI = &MF.getRegInfo();
1339   MDT = &getAnalysis<MachineDominatorTreeWrapperPass>().getDomTree();
1340 
1341   const RISCVSubtarget &ST = MF.getSubtarget<RISCVSubtarget>();
1342   if (!ST.hasVInstructions())
1343     return false;
1344 
1345   // For each instruction that defines a vector, compute what VL its
1346   // downstream users demand.
1347   for (MachineBasicBlock *MBB : post_order(&MF)) {
1348     assert(MDT->isReachableFromEntry(MBB));
1349     for (MachineInstr &MI : reverse(*MBB)) {
1350       if (!isCandidate(MI))
1351         continue;
1352       DemandedVLs.insert({&MI, checkUsers(MI)});
1353     }
1354   }
1355 
1356   // Then go through and see if we can reduce the VL of any instructions to
1357   // only what's demanded.
1358   bool MadeChange = false;
1359   for (MachineBasicBlock &MBB : MF) {
1360     // Avoid unreachable blocks as they have degenerate dominance
1361     if (!MDT->isReachableFromEntry(&MBB))
1362       continue;
1363 
1364     for (auto &MI : reverse(MBB)) {
1365       if (!isCandidate(MI))
1366         continue;
1367       if (!tryReduceVL(MI))
1368         continue;
1369       MadeChange = true;
1370     }
1371   }
1372 
1373   return MadeChange;
1374 }
1375