//===- RISCVVectorPeephole.cpp - MI Vector Pseudo Peepholes ---------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This pass performs various vector pseudo peephole optimizations after
// instruction selection.
//
// Currently it converts vmerge.vvm to vmv.v.v
// PseudoVMERGE_VVM %false, %false, %true, %allonesmask, %vl, %sew
// ->
// PseudoVMV_V_V %false, %true, %vl, %sew
//
// And masked pseudos to unmasked pseudos
// PseudoVADD_V_V_MASK %passthru, %a, %b, %allonesmask, %vl, sew, policy
// ->
// PseudoVADD_V_V %passthru, %a, %b, %vl, sew, policy
//
// It also converts AVLs to VLMAX where possible
// %vl = VLENB * something
// PseudoVADD_V_V %passthru, %a, %b, %vl, sew, policy
// ->
// PseudoVADD_V_V %passthru, %a, %b, -1, sew, policy
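//
// For example, when LMUL=8 and SEW=64, VLMAX = VLEN * 8 / 64 = VLENB, so an
// AVL read directly from PseudoReadVLENB can be replaced with the sentinel.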
//
//===----------------------------------------------------------------------===//

#include "RISCV.h"
#include "RISCVSubtarget.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/CodeGen/MachineFunctionPass.h"
#include "llvm/CodeGen/MachineRegisterInfo.h"
#include "llvm/CodeGen/TargetInstrInfo.h"
#include "llvm/CodeGen/TargetRegisterInfo.h"

using namespace llvm;

#define DEBUG_TYPE "riscv-vector-peephole"

namespace {

class RISCVVectorPeephole : public MachineFunctionPass {
public:
  static char ID;
  const TargetInstrInfo *TII;
  MachineRegisterInfo *MRI;
  const TargetRegisterInfo *TRI;
  const RISCVSubtarget *ST;
  RISCVVectorPeephole() : MachineFunctionPass(ID) {}

  bool runOnMachineFunction(MachineFunction &MF) override;
  MachineFunctionProperties getRequiredProperties() const override {
    return MachineFunctionProperties().set(
        MachineFunctionProperties::Property::IsSSA);
  }

  StringRef getPassName() const override {
    return "RISC-V Vector Peephole Optimization";
  }

private:
  bool tryToReduceVL(MachineInstr &MI) const;
  bool convertToVLMAX(MachineInstr &MI) const;
  bool convertToWholeRegister(MachineInstr &MI) const;
  bool convertToUnmasked(MachineInstr &MI) const;
  bool convertAllOnesVMergeToVMv(MachineInstr &MI) const;
  bool convertSameMaskVMergeToVMv(MachineInstr &MI);
  bool foldUndefPassthruVMV_V_V(MachineInstr &MI);
  bool foldVMV_V_V(MachineInstr &MI);

  bool hasSameEEW(const MachineInstr &User, const MachineInstr &Src) const;
  bool isAllOnesMask(const MachineInstr *MaskDef) const;
  std::optional<unsigned> getConstant(const MachineOperand &VL) const;
  bool ensureDominates(const MachineOperand &Use, MachineInstr &Src) const;

  /// Maps uses of V0 to the corresponding def of V0.
  DenseMap<const MachineInstr *, const MachineInstr *> V0Defs;
};

} // namespace

char RISCVVectorPeephole::ID = 0;

INITIALIZE_PASS(RISCVVectorPeephole, DEBUG_TYPE,
                "RISC-V Vector Peephole Optimization", false, false)

/// Given \p User, which has an input operand with EEW=SEW that uses the
/// destination operand of \p Src (whose EEW may differ from its SEW), return
/// true if their EEWs match.
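/// For example, the destination of a widening PseudoVWADD_VV with SEW=32 has
/// EEW=64, so it matches a user whose SEW is 64 but not one whose SEW is 32.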
bool RISCVVectorPeephole::hasSameEEW(const MachineInstr &User,
                                     const MachineInstr &Src) const {
  unsigned UserLog2SEW =
      User.getOperand(RISCVII::getSEWOpNum(User.getDesc())).getImm();
  unsigned SrcLog2SEW =
      Src.getOperand(RISCVII::getSEWOpNum(Src.getDesc())).getImm();
  unsigned SrcLog2EEW = RISCV::getDestLog2EEW(
      TII->get(RISCV::getRVVMCOpcode(Src.getOpcode())), SrcLog2SEW);
  return SrcLog2EEW == UserLog2SEW;
}

// Attempt to reduce the VL of an instruction whose sole use is feeding an
// instruction with a narrower VL.  This currently works backwards from the
// user instruction (which might have a smaller VL).
bool RISCVVectorPeephole::tryToReduceVL(MachineInstr &MI) const {
  // Note that the goal here is a bit multifaceted.
  // 1) For stores, reducing the VL of the value being stored may help to
  //    reduce VL toggles.  This is somewhat of an artifact of the fact that
  //    we promote arithmetic instructions but VL-predicate stores.
  // 2) For vmv.v.v, reducing VL eagerly on the source instruction allows us
  //    to share code with the foldVMV_V_V transform below.
  //
  // Note that to the best of our knowledge, reducing VL is generally not a
  // significant win on real hardware unless we can also reduce LMUL, which
  // this code doesn't try to do.
  //
  // TODO: We can handle a bunch more instructions here, and probably
  // recurse backwards through operands too.
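  //
  // For example, if a VADD_VV with AVL=VLMAX feeds a VSE32_V with AVL=4, the
  // add's AVL can be shrunk to 4 since only the first four elements are
  // stored.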
  unsigned SrcIdx = 0;
  switch (RISCV::getRVVMCOpcode(MI.getOpcode())) {
  default:
    return false;
  case RISCV::VSE8_V:
  case RISCV::VSE16_V:
  case RISCV::VSE32_V:
  case RISCV::VSE64_V:
    break;
  case RISCV::VMV_V_V:
    SrcIdx = 2;
    break;
  case RISCV::VMERGE_VVM:
    SrcIdx = 3; // TODO: We can also handle the false operand.
    break;
  case RISCV::VREDSUM_VS:
  case RISCV::VREDMAXU_VS:
  case RISCV::VREDMAX_VS:
  case RISCV::VREDMINU_VS:
  case RISCV::VREDMIN_VS:
  case RISCV::VREDAND_VS:
  case RISCV::VREDOR_VS:
  case RISCV::VREDXOR_VS:
  case RISCV::VWREDSUM_VS:
  case RISCV::VWREDSUMU_VS:
  case RISCV::VFREDUSUM_VS:
  case RISCV::VFREDOSUM_VS:
  case RISCV::VFREDMAX_VS:
  case RISCV::VFREDMIN_VS:
  case RISCV::VFWREDUSUM_VS:
  case RISCV::VFWREDOSUM_VS:
    SrcIdx = 2;
    break;
  }

  MachineOperand &VL = MI.getOperand(RISCVII::getVLOpNum(MI.getDesc()));
  if (VL.isImm() && VL.getImm() == RISCV::VLMaxSentinel)
    return false;

  Register SrcReg = MI.getOperand(SrcIdx).getReg();
  // Note: one *use*, not one *user*.
  if (!MRI->hasOneUse(SrcReg))
    return false;

  MachineInstr *Src = MRI->getVRegDef(SrcReg);
  if (!Src || Src->hasUnmodeledSideEffects() ||
      Src->getParent() != MI.getParent() || Src->getNumDefs() != 1 ||
      !RISCVII::hasVLOp(Src->getDesc().TSFlags) ||
      !RISCVII::hasSEWOp(Src->getDesc().TSFlags))
    return false;

  // Src's dest needs to have the same EEW as MI's input.
  if (!hasSameEEW(MI, *Src))
    return false;

  bool ElementsDependOnVL = RISCVII::elementsDependOnVL(
      TII->get(RISCV::getRVVMCOpcode(Src->getOpcode())).TSFlags);
  if (ElementsDependOnVL || Src->mayRaiseFPException())
    return false;

  MachineOperand &SrcVL = Src->getOperand(RISCVII::getVLOpNum(Src->getDesc()));
  if (VL.isIdenticalTo(SrcVL) || !RISCV::isVLKnownLE(VL, SrcVL))
    return false;

  if (!ensureDominates(VL, *Src))
    return false;

  if (VL.isImm())
    SrcVL.ChangeToImmediate(VL.getImm());
  else if (VL.isReg())
    SrcVL.ChangeToRegister(VL.getReg(), false);

  // TODO: For instructions with a passthru, we could clear the passthru
  // and tail policy since we've just proven the tail is not demanded.
  return true;
}

/// If \p VL is an immediate, or a register materialized by an ADDI $x0, imm,
/// return the constant value. Otherwise return std::nullopt.
std::optional<unsigned>
RISCVVectorPeephole::getConstant(const MachineOperand &VL) const {
  if (VL.isImm())
    return VL.getImm();

  MachineInstr *Def = MRI->getVRegDef(VL.getReg());
  if (!Def || Def->getOpcode() != RISCV::ADDI ||
      Def->getOperand(1).getReg() != RISCV::X0)
    return std::nullopt;
  return Def->getOperand(2).getImm();
}

/// Convert AVLs that are known to be VLMAX to the VLMAX sentinel.
bool RISCVVectorPeephole::convertToVLMAX(MachineInstr &MI) const {
  if (!RISCVII::hasVLOp(MI.getDesc().TSFlags) ||
      !RISCVII::hasSEWOp(MI.getDesc().TSFlags))
    return false;

  auto LMUL = RISCVVType::decodeVLMUL(RISCVII::getLMul(MI.getDesc().TSFlags));
  // Fixed-point value, denominator=8
  unsigned LMULFixed = LMUL.second ? (8 / LMUL.first) : 8 * LMUL.first;
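  // e.g. LMUL=1/2 gives LMULFixed=4, LMUL=1 gives LMULFixed=8, and LMUL=4
  // gives LMULFixed=32.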
  unsigned Log2SEW = MI.getOperand(RISCVII::getSEWOpNum(MI.getDesc())).getImm();
  // A Log2SEW of 0 is an operation on mask registers only
  unsigned SEW = Log2SEW ? 1 << Log2SEW : 8;
  assert(RISCVVType::isValidSEW(SEW) && "Unexpected SEW");
  assert(8 * LMULFixed / SEW > 0);

  // If the exact VLEN is known then we know VLMAX, check if the AVL == VLMAX.
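  // VLMAX = VLEN * LMUL / SEW = (VLEN * LMULFixed) / (8 * SEW), so checking
  // VLEN * LMULFixed / SEW == AVL * 8 tests AVL == VLMAX while keeping the
  // arithmetic integral.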
  MachineOperand &VL = MI.getOperand(RISCVII::getVLOpNum(MI.getDesc()));
  if (auto VLen = ST->getRealVLen(), AVL = getConstant(VL);
      VLen && AVL && (*VLen * LMULFixed) / SEW == *AVL * 8) {
    VL.ChangeToImmediate(RISCV::VLMaxSentinel);
    return true;
  }

  // If the AVL is a VLENB value, possibly scaled with slli/srli so that it
  // equals VLMAX, convert it to the VLMAX sentinel value.
  if (!VL.isReg())
    return false;
  MachineInstr *Def = MRI->getVRegDef(VL.getReg());
  if (!Def)
    return false;

  // Fixed-point value, denominator=8
  uint64_t ScaleFixed = 8;
  // Check if the VLENB was potentially scaled with slli/srli
  if (Def->getOpcode() == RISCV::SLLI) {
    assert(Def->getOperand(2).getImm() < 64);
    ScaleFixed <<= Def->getOperand(2).getImm();
    Def = MRI->getVRegDef(Def->getOperand(1).getReg());
  } else if (Def->getOpcode() == RISCV::SRLI) {
    assert(Def->getOperand(2).getImm() < 64);
    ScaleFixed >>= Def->getOperand(2).getImm();
    Def = MRI->getVRegDef(Def->getOperand(1).getReg());
  }

  if (!Def || Def->getOpcode() != RISCV::PseudoReadVLENB)
    return false;

  // AVL = (VLENB * Scale)
  //
  // VLMAX = (VLENB * 8 * LMUL) / SEW
  //
  // AVL == VLMAX
  // -> VLENB * Scale == (VLENB * 8 * LMUL) / SEW
  // -> Scale == (8 * LMUL) / SEW
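  //
  // e.g. with LMUL=2 and SEW=16, the required Scale is 8 * 2 / 16 = 1
  // (ScaleFixed == 8), i.e. an unscaled VLENB: VLMAX = VLEN * 2 / 16 = VLENB.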
  if (ScaleFixed != 8 * LMULFixed / SEW)
    return false;

  VL.ChangeToImmediate(RISCV::VLMaxSentinel);

  return true;
}

bool RISCVVectorPeephole::isAllOnesMask(const MachineInstr *MaskDef) const {
  assert(MaskDef && MaskDef->isCopy() &&
         MaskDef->getOperand(0).getReg() == RISCV::V0);
  Register SrcReg = TRI->lookThruCopyLike(MaskDef->getOperand(1).getReg(), MRI);
  if (!SrcReg.isVirtual())
    return false;
  MaskDef = MRI->getVRegDef(SrcReg);
  if (!MaskDef)
    return false;

  // TODO: Check that the VMSET is the expected bitwidth? The pseudo has
  // undefined behaviour if it's the wrong bitwidth, so we could choose to
  // assume that it's all-ones? Same applies to its VL.
  switch (MaskDef->getOpcode()) {
  case RISCV::PseudoVMSET_M_B1:
  case RISCV::PseudoVMSET_M_B2:
  case RISCV::PseudoVMSET_M_B4:
  case RISCV::PseudoVMSET_M_B8:
  case RISCV::PseudoVMSET_M_B16:
  case RISCV::PseudoVMSET_M_B32:
  case RISCV::PseudoVMSET_M_B64:
    return true;
  default:
    return false;
  }
}

/// Convert unit-strided unmasked loads and stores to whole-register
/// equivalents to avoid the dependency on $vl and $vtype.
///
/// %x = PseudoVLE8_V_M1 %passthru, %ptr, %vlmax, policy
/// PseudoVSE8_V_M1 %v, %ptr, %vlmax
///
/// ->
///
/// %x = VL1RE8_V %ptr
/// VS1R_V %v, %ptr
bool RISCVVectorPeephole::convertToWholeRegister(MachineInstr &MI) const {
#define CASE_WHOLE_REGISTER_LMUL_SEW(lmul, sew)                                \
  case RISCV::PseudoVLE##sew##_V_M##lmul:                                      \
    NewOpc = RISCV::VL##lmul##RE##sew##_V;                                     \
    break;                                                                     \
  case RISCV::PseudoVSE##sew##_V_M##lmul:                                      \
    NewOpc = RISCV::VS##lmul##R_V;                                             \
    break;
#define CASE_WHOLE_REGISTER_LMUL(lmul)                                         \
  CASE_WHOLE_REGISTER_LMUL_SEW(lmul, 8)                                        \
  CASE_WHOLE_REGISTER_LMUL_SEW(lmul, 16)                                       \
  CASE_WHOLE_REGISTER_LMUL_SEW(lmul, 32)                                       \
  CASE_WHOLE_REGISTER_LMUL_SEW(lmul, 64)

  unsigned NewOpc;
  switch (MI.getOpcode()) {
    CASE_WHOLE_REGISTER_LMUL(1)
    CASE_WHOLE_REGISTER_LMUL(2)
    CASE_WHOLE_REGISTER_LMUL(4)
    CASE_WHOLE_REGISTER_LMUL(8)
  default:
    return false;
  }

  MachineOperand &VLOp = MI.getOperand(RISCVII::getVLOpNum(MI.getDesc()));
  if (!VLOp.isImm() || VLOp.getImm() != RISCV::VLMaxSentinel)
    return false;

  // Whole register instructions aren't pseudos so they don't have
  // policy/SEW/AVL ops, and they don't have passthrus.
  if (RISCVII::hasVecPolicyOp(MI.getDesc().TSFlags))
    MI.removeOperand(RISCVII::getVecPolicyOpNum(MI.getDesc()));
  MI.removeOperand(RISCVII::getSEWOpNum(MI.getDesc()));
  MI.removeOperand(RISCVII::getVLOpNum(MI.getDesc()));
  if (RISCVII::isFirstDefTiedToFirstUse(MI.getDesc()))
    MI.removeOperand(1);

  MI.setDesc(TII->get(NewOpc));

  return true;
}

static unsigned getVMV_V_VOpcodeForVMERGE_VVM(const MachineInstr &MI) {
#define CASE_VMERGE_TO_VMV(lmul)                                               \
  case RISCV::PseudoVMERGE_VVM_##lmul:                                         \
    return RISCV::PseudoVMV_V_V_##lmul;
  switch (MI.getOpcode()) {
  default:
    return 0;
    CASE_VMERGE_TO_VMV(MF8)
    CASE_VMERGE_TO_VMV(MF4)
    CASE_VMERGE_TO_VMV(MF2)
    CASE_VMERGE_TO_VMV(M1)
    CASE_VMERGE_TO_VMV(M2)
    CASE_VMERGE_TO_VMV(M4)
    CASE_VMERGE_TO_VMV(M8)
  }
}

/// Convert a PseudoVMERGE_VVM with an all-ones mask to a PseudoVMV_V_V.
///
/// %x = PseudoVMERGE_VVM %passthru, %false, %true, %allones, vl, sew
/// ->
/// %x = PseudoVMV_V_V %passthru, %true, vl, sew, tu_mu
bool RISCVVectorPeephole::convertAllOnesVMergeToVMv(MachineInstr &MI) const {
  unsigned NewOpc = getVMV_V_VOpcodeForVMERGE_VVM(MI);
  if (!NewOpc)
    return false;
  assert(MI.getOperand(4).isReg() && MI.getOperand(4).getReg() == RISCV::V0);
  if (!isAllOnesMask(V0Defs.lookup(&MI)))
    return false;

  MI.setDesc(TII->get(NewOpc));
  MI.removeOperand(2); // False operand
  MI.removeOperand(3); // Mask operand
  MI.addOperand(
      MachineOperand::CreateImm(RISCVII::TAIL_UNDISTURBED_MASK_UNDISTURBED));

  // vmv.v.v doesn't have a mask operand, so we may be able to inflate the
  // register class for the destination and passthru operands, e.g. VRNoV0 -> VR.
  MRI->recomputeRegClass(MI.getOperand(0).getReg());
  if (MI.getOperand(1).getReg() != RISCV::NoRegister)
    MRI->recomputeRegClass(MI.getOperand(1).getReg());
  return true;
}

/// If a PseudoVMERGE_VVM's true operand is a masked pseudo and both have the
/// same mask, and the masked pseudo's passthru is the same as the false
/// operand, we can convert the PseudoVMERGE_VVM to a PseudoVMV_V_V.
///
/// %true = PseudoVADD_VV_M1_MASK %false, %a, %b, %mask, vl1, sew, policy
/// %x = PseudoVMERGE_VVM %passthru, %false, %true, %mask, vl2, sew
/// ->
/// %true = PseudoVADD_VV_M1_MASK %false, %a, %b, %mask, vl1, sew, policy
/// %x = PseudoVMV_V_V %passthru, %true, vl2, sew, tu_mu
bool RISCVVectorPeephole::convertSameMaskVMergeToVMv(MachineInstr &MI) {
  unsigned NewOpc = getVMV_V_VOpcodeForVMERGE_VVM(MI);
  if (!NewOpc)
    return false;
  MachineInstr *True = MRI->getVRegDef(MI.getOperand(3).getReg());
  if (!True || True->getParent() != MI.getParent() ||
      !RISCV::getMaskedPseudoInfo(True->getOpcode()) || !hasSameEEW(MI, *True))
    return false;

  const MachineInstr *TrueV0Def = V0Defs.lookup(True);
  const MachineInstr *MIV0Def = V0Defs.lookup(&MI);
  assert(TrueV0Def && TrueV0Def->isCopy() && MIV0Def && MIV0Def->isCopy());
  if (TrueV0Def->getOperand(1).getReg() != MIV0Def->getOperand(1).getReg())
    return false;

  // True's passthru needs to be equivalent to False.
  Register TruePassthruReg = True->getOperand(1).getReg();
  Register FalseReg = MI.getOperand(2).getReg();
  if (TruePassthruReg != FalseReg) {
    // If True's passthru is undef, see if we can change it to False.
    if (TruePassthruReg != RISCV::NoRegister ||
        !MRI->hasOneUse(MI.getOperand(3).getReg()) ||
        !ensureDominates(MI.getOperand(2), *True))
      return false;
    True->getOperand(1).setReg(MI.getOperand(2).getReg());
    // If True is masked then its passthru needs to be in VRNoV0.
    MRI->constrainRegClass(True->getOperand(1).getReg(),
                           TII->getRegClass(True->getDesc(), 1, TRI,
                                            *True->getParent()->getParent()));
  }

  MI.setDesc(TII->get(NewOpc));
  MI.removeOperand(2); // False operand
  MI.removeOperand(3); // Mask operand
  MI.addOperand(
      MachineOperand::CreateImm(RISCVII::TAIL_UNDISTURBED_MASK_UNDISTURBED));

  // vmv.v.v doesn't have a mask operand, so we may be able to inflate the
  // register class for the destination and passthru operands, e.g. VRNoV0 -> VR.
  MRI->recomputeRegClass(MI.getOperand(0).getReg());
  if (MI.getOperand(1).getReg() != RISCV::NoRegister)
    MRI->recomputeRegClass(MI.getOperand(1).getReg());
  return true;
}

bool RISCVVectorPeephole::convertToUnmasked(MachineInstr &MI) const {
  const RISCV::RISCVMaskedPseudoInfo *I =
      RISCV::getMaskedPseudoInfo(MI.getOpcode());
  if (!I)
    return false;

  if (!isAllOnesMask(V0Defs.lookup(&MI)))
    return false;

  // There are two classes of pseudos in the table - compares and
  // everything else.  See the comment on RISCVMaskedPseudo for details.
  const unsigned Opc = I->UnmaskedPseudo;
  const MCInstrDesc &MCID = TII->get(Opc);
  [[maybe_unused]] const bool HasPolicyOp =
      RISCVII::hasVecPolicyOp(MCID.TSFlags);
  const bool HasPassthru = RISCVII::isFirstDefTiedToFirstUse(MCID);
  const MCInstrDesc &MaskedMCID = TII->get(MI.getOpcode());
  assert(RISCVII::hasVecPolicyOp(MaskedMCID.TSFlags) ==
             RISCVII::hasVecPolicyOp(MCID.TSFlags) &&
         "Masked and unmasked pseudos are inconsistent");
  assert(HasPolicyOp == HasPassthru && "Unexpected pseudo structure");
  assert(!(HasPassthru && !RISCVII::isFirstDefTiedToFirstUse(MaskedMCID)) &&
         "Unmasked with passthru but masked with no passthru?");
  (void)HasPolicyOp;

  MI.setDesc(MCID);

  // TODO: Increment all MaskOpIdxs in tablegen by num of explicit defs?
  unsigned MaskOpIdx = I->MaskOpIdx + MI.getNumExplicitDefs();
  MI.removeOperand(MaskOpIdx);

  // The unmasked pseudo will no longer be constrained to the vrnov0 reg class,
  // so try and relax it to vr.
  MRI->recomputeRegClass(MI.getOperand(0).getReg());

  // If the original masked pseudo had a passthru, relax it or remove it.
  if (RISCVII::isFirstDefTiedToFirstUse(MaskedMCID)) {
    unsigned PassthruOpIdx = MI.getNumExplicitDefs();
    if (HasPassthru) {
      if (MI.getOperand(PassthruOpIdx).getReg() != RISCV::NoRegister)
        MRI->recomputeRegClass(MI.getOperand(PassthruOpIdx).getReg());
    } else
      MI.removeOperand(PassthruOpIdx);
  }

  return true;
}

/// Check if it's safe to move From down to To, checking that no physical
/// registers are clobbered.
static bool isSafeToMove(const MachineInstr &From, const MachineInstr &To) {
  assert(From.getParent() == To.getParent() && !From.hasImplicitDef());
  SmallVector<Register> PhysUses;
  for (const MachineOperand &MO : From.all_uses())
    if (MO.getReg().isPhysical())
      PhysUses.push_back(MO.getReg());
  bool SawStore = false;
  for (auto II = From.getIterator(); II != To.getIterator(); II++) {
    for (Register PhysReg : PhysUses)
      if (II->definesRegister(PhysReg, nullptr))
        return false;
    if (II->mayStore()) {
      SawStore = true;
      break;
    }
  }
  return From.isSafeToMove(SawStore);
}

/// Given that \p A and \p B are in the same MBB, return true if \p A comes
/// before \p B.
static bool dominates(MachineBasicBlock::const_iterator A,
                      MachineBasicBlock::const_iterator B) {
  assert(A->getParent() == B->getParent());
  const MachineBasicBlock *MBB = A->getParent();
  auto MBBEnd = MBB->end();
  if (B == MBBEnd)
    return true;

  MachineBasicBlock::const_iterator I = MBB->begin();
  for (; &*I != A && &*I != B; ++I)
    ;

  return &*I == A;
}

/// If the register in \p MO doesn't dominate \p Src, try to move \p Src so it
/// does. Returns false if it doesn't dominate and we can't move \p Src.
/// \p MO must be in the same basic block as \p Src.
bool RISCVVectorPeephole::ensureDominates(const MachineOperand &MO,
                                          MachineInstr &Src) const {
  assert(MO.getParent()->getParent() == Src.getParent());
  if (!MO.isReg() || MO.getReg() == RISCV::NoRegister)
    return true;

  MachineInstr *Def = MRI->getVRegDef(MO.getReg());
  if (Def->getParent() == Src.getParent() && !dominates(Def, Src)) {
    if (!isSafeToMove(Src, *Def->getNextNode()))
      return false;
    Src.moveBefore(Def->getNextNode());
  }

  return true;
}

/// If a PseudoVMV_V_V's passthru is undef then we can replace it with its
/// input.
bool RISCVVectorPeephole::foldUndefPassthruVMV_V_V(MachineInstr &MI) {
  if (RISCV::getRVVMCOpcode(MI.getOpcode()) != RISCV::VMV_V_V)
    return false;
  if (MI.getOperand(1).getReg() != RISCV::NoRegister)
    return false;

  // If the input was a pseudo with a policy operand, we can give it a tail
  // agnostic policy if MI's undef tail subsumes the input's.
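  // (MI's tail is undemanded since its passthru is undef, and the input's
  //  tail lies within MI's tail whenever MI's VL <= the input's VL.)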
  MachineInstr *Src = MRI->getVRegDef(MI.getOperand(2).getReg());
  if (Src && !Src->hasUnmodeledSideEffects() &&
      MRI->hasOneUse(MI.getOperand(2).getReg()) &&
      RISCVII::hasVLOp(Src->getDesc().TSFlags) &&
      RISCVII::hasVecPolicyOp(Src->getDesc().TSFlags) && hasSameEEW(MI, *Src)) {
    const MachineOperand &MIVL = MI.getOperand(3);
    const MachineOperand &SrcVL =
        Src->getOperand(RISCVII::getVLOpNum(Src->getDesc()));

    MachineOperand &SrcPolicy =
        Src->getOperand(RISCVII::getVecPolicyOpNum(Src->getDesc()));

    if (RISCV::isVLKnownLE(MIVL, SrcVL))
      SrcPolicy.setImm(SrcPolicy.getImm() | RISCVII::TAIL_AGNOSTIC);
  }

  MRI->replaceRegWith(MI.getOperand(0).getReg(), MI.getOperand(2).getReg());
  MI.eraseFromParent();
  V0Defs.erase(&MI);
  return true;
}

/// If a PseudoVMV_V_V is the only user of its input, fold its passthru and VL
/// into it.
///
/// %x = PseudoVADD_V_V_M1 %passthru, %a, %b, %vl1, sew, policy
/// %y = PseudoVMV_V_V_M1 %passthru, %x, %vl2, sew, policy
///    (where %vl1 <= %vl2, see related tryToReduceVL)
///
/// ->
///
/// %y = PseudoVADD_V_V_M1 %passthru, %a, %b, vl1, sew, policy
bool RISCVVectorPeephole::foldVMV_V_V(MachineInstr &MI) {
  if (RISCV::getRVVMCOpcode(MI.getOpcode()) != RISCV::VMV_V_V)
    return false;

  MachineOperand &Passthru = MI.getOperand(1);

  if (!MRI->hasOneUse(MI.getOperand(2).getReg()))
    return false;

  MachineInstr *Src = MRI->getVRegDef(MI.getOperand(2).getReg());
  if (!Src || Src->hasUnmodeledSideEffects() ||
      Src->getParent() != MI.getParent() || Src->getNumDefs() != 1 ||
      !RISCVII::isFirstDefTiedToFirstUse(Src->getDesc()) ||
      !RISCVII::hasVLOp(Src->getDesc().TSFlags) ||
      !RISCVII::hasVecPolicyOp(Src->getDesc().TSFlags))
    return false;

  // Src's dest needs to have the same EEW as MI's input.
  if (!hasSameEEW(MI, *Src))
    return false;

  // Src needs to have the same passthru as VMV_V_V.
  MachineOperand &SrcPassthru = Src->getOperand(1);
  if (SrcPassthru.getReg() != RISCV::NoRegister &&
      SrcPassthru.getReg() != Passthru.getReg())
    return false;

  // Src's VL will have already been reduced if legal (see tryToReduceVL),
  // so we don't need to handle a smaller source VL here.  However, the
  // user's VL may be larger.
  MachineOperand &SrcVL = Src->getOperand(RISCVII::getVLOpNum(Src->getDesc()));
  if (!RISCV::isVLKnownLE(SrcVL, MI.getOperand(3)))
    return false;

  // If the new passthru doesn't dominate Src, try to move Src so it does.
  if (!ensureDominates(Passthru, *Src))
    return false;

  if (SrcPassthru.getReg() != Passthru.getReg()) {
    SrcPassthru.setReg(Passthru.getReg());
    // If Src is masked then its passthru needs to be in VRNoV0.
    if (Passthru.getReg() != RISCV::NoRegister)
      MRI->constrainRegClass(Passthru.getReg(),
                             TII->getRegClass(Src->getDesc(), 1, TRI,
                                              *Src->getParent()->getParent()));
  }

  // If MI was tail agnostic and the VL didn't increase, preserve it.
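  // (Src's tail begins at SrcVL and MI's agnostic tail begins at MI's VL, so
  //  Src's tail is only known to be agnostic when MI's VL <= SrcVL.)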
  int64_t Policy = RISCVII::TAIL_UNDISTURBED_MASK_UNDISTURBED;
  if ((MI.getOperand(5).getImm() & RISCVII::TAIL_AGNOSTIC) &&
      RISCV::isVLKnownLE(MI.getOperand(3), SrcVL))
    Policy |= RISCVII::TAIL_AGNOSTIC;
  Src->getOperand(RISCVII::getVecPolicyOpNum(Src->getDesc())).setImm(Policy);

  MRI->replaceRegWith(MI.getOperand(0).getReg(), Src->getOperand(0).getReg());
  MI.eraseFromParent();
  V0Defs.erase(&MI);

  return true;
}

bool RISCVVectorPeephole::runOnMachineFunction(MachineFunction &MF) {
  if (skipFunction(MF.getFunction()))
    return false;

  // Skip if the vector extension is not enabled.
  ST = &MF.getSubtarget<RISCVSubtarget>();
  if (!ST->hasVInstructions())
    return false;

  TII = ST->getInstrInfo();
  MRI = &MF.getRegInfo();
  TRI = MRI->getTargetRegisterInfo();

  bool Changed = false;

  // Masked pseudos coming out of isel will have their mask operand in the form:
  //
  // $v0:vr = COPY %mask:vr
  // %x:vr = Pseudo_MASK %a:vr, %b:vr, $v0:vr
  //
  // Because $v0 isn't in SSA, keep track of its definition at each use so we
  // can check mask operands.
  for (const MachineBasicBlock &MBB : MF) {
    const MachineInstr *CurrentV0Def = nullptr;
    for (const MachineInstr &MI : MBB) {
      if (MI.readsRegister(RISCV::V0, TRI))
        V0Defs[&MI] = CurrentV0Def;

      if (MI.definesRegister(RISCV::V0, TRI))
        CurrentV0Def = &MI;
    }
  }

  for (MachineBasicBlock &MBB : MF) {
    for (MachineInstr &MI : make_early_inc_range(MBB)) {
      Changed |= convertToVLMAX(MI);
      Changed |= tryToReduceVL(MI);
      Changed |= convertToUnmasked(MI);
      Changed |= convertToWholeRegister(MI);
      Changed |= convertAllOnesVMergeToVMv(MI);
      Changed |= convertSameMaskVMergeToVMv(MI);
      if (foldUndefPassthruVMV_V_V(MI)) {
        Changed = true;
        continue; // MI is erased
      }
      Changed |= foldVMV_V_V(MI);
    }
  }

  return Changed;
}

FunctionPass *llvm::createRISCVVectorPeepholePass() {
  return new RISCVVectorPeephole();
}