xref: /freebsd-src/contrib/llvm-project/llvm/lib/Target/RISCV/RISCVVectorPeephole.cpp (revision 0fca6ea1d4eea4c934cfff25ac9ee8ad6fe95583)
1*0fca6ea1SDimitry Andric //===- RISCVVectorPeephole.cpp - MI Vector Pseudo Peepholes ---------------===//
2*0fca6ea1SDimitry Andric //
3*0fca6ea1SDimitry Andric // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4*0fca6ea1SDimitry Andric // See https://llvm.org/LICENSE.txt for license information.
5*0fca6ea1SDimitry Andric // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6*0fca6ea1SDimitry Andric //
7*0fca6ea1SDimitry Andric //===----------------------------------------------------------------------===//
8*0fca6ea1SDimitry Andric //
9*0fca6ea1SDimitry Andric // This pass performs various vector pseudo peephole optimisations after
10*0fca6ea1SDimitry Andric // instruction selection.
11*0fca6ea1SDimitry Andric //
12*0fca6ea1SDimitry Andric // Currently it converts vmerge.vvm to vmv.v.v
13*0fca6ea1SDimitry Andric // PseudoVMERGE_VVM %false, %false, %true, %allonesmask, %vl, %sew
14*0fca6ea1SDimitry Andric // ->
15*0fca6ea1SDimitry Andric // PseudoVMV_V_V %false, %true, %vl, %sew
16*0fca6ea1SDimitry Andric //
17*0fca6ea1SDimitry Andric // And masked pseudos to unmasked pseudos
18*0fca6ea1SDimitry Andric // PseudoVADD_V_V_MASK %passthru, %a, %b, %allonesmask, %vl, sew, policy
19*0fca6ea1SDimitry Andric // ->
20*0fca6ea1SDimitry Andric // PseudoVADD_V_V %passthru, %a, %b, %vl, sew, policy
21*0fca6ea1SDimitry Andric //
22*0fca6ea1SDimitry Andric // It also converts AVLs to VLMAX where possible
23*0fca6ea1SDimitry Andric // %vl = VLENB * something
24*0fca6ea1SDimitry Andric // PseudoVADD_V_V %passthru, %a, %b, %vl, sew, policy
25*0fca6ea1SDimitry Andric // ->
26*0fca6ea1SDimitry Andric // PseudoVADD_V_V %passthru, %a, %b, -1, sew, policy
27*0fca6ea1SDimitry Andric //
28*0fca6ea1SDimitry Andric //===----------------------------------------------------------------------===//
29*0fca6ea1SDimitry Andric 
30*0fca6ea1SDimitry Andric #include "RISCV.h"
31*0fca6ea1SDimitry Andric #include "RISCVISelDAGToDAG.h"
32*0fca6ea1SDimitry Andric #include "RISCVSubtarget.h"
33*0fca6ea1SDimitry Andric #include "llvm/CodeGen/MachineFunctionPass.h"
34*0fca6ea1SDimitry Andric #include "llvm/CodeGen/MachineRegisterInfo.h"
35*0fca6ea1SDimitry Andric #include "llvm/CodeGen/TargetInstrInfo.h"
36*0fca6ea1SDimitry Andric #include "llvm/CodeGen/TargetRegisterInfo.h"
37*0fca6ea1SDimitry Andric 
38*0fca6ea1SDimitry Andric using namespace llvm;
39*0fca6ea1SDimitry Andric 
40*0fca6ea1SDimitry Andric #define DEBUG_TYPE "riscv-vector-peephole"
41*0fca6ea1SDimitry Andric 
42*0fca6ea1SDimitry Andric namespace {
43*0fca6ea1SDimitry Andric 
44*0fca6ea1SDimitry Andric class RISCVVectorPeephole : public MachineFunctionPass {
45*0fca6ea1SDimitry Andric public:
46*0fca6ea1SDimitry Andric   static char ID;
47*0fca6ea1SDimitry Andric   const TargetInstrInfo *TII;
48*0fca6ea1SDimitry Andric   MachineRegisterInfo *MRI;
49*0fca6ea1SDimitry Andric   const TargetRegisterInfo *TRI;
50*0fca6ea1SDimitry Andric   RISCVVectorPeephole() : MachineFunctionPass(ID) {}
51*0fca6ea1SDimitry Andric 
52*0fca6ea1SDimitry Andric   bool runOnMachineFunction(MachineFunction &MF) override;
53*0fca6ea1SDimitry Andric   MachineFunctionProperties getRequiredProperties() const override {
54*0fca6ea1SDimitry Andric     return MachineFunctionProperties().set(
55*0fca6ea1SDimitry Andric         MachineFunctionProperties::Property::IsSSA);
56*0fca6ea1SDimitry Andric   }
57*0fca6ea1SDimitry Andric 
58*0fca6ea1SDimitry Andric   StringRef getPassName() const override { return "RISC-V Fold Masks"; }
59*0fca6ea1SDimitry Andric 
60*0fca6ea1SDimitry Andric private:
61*0fca6ea1SDimitry Andric   bool convertToVLMAX(MachineInstr &MI) const;
62*0fca6ea1SDimitry Andric   bool convertToUnmasked(MachineInstr &MI) const;
63*0fca6ea1SDimitry Andric   bool convertVMergeToVMv(MachineInstr &MI) const;
64*0fca6ea1SDimitry Andric 
65*0fca6ea1SDimitry Andric   bool isAllOnesMask(const MachineInstr *MaskDef) const;
66*0fca6ea1SDimitry Andric 
67*0fca6ea1SDimitry Andric   /// Maps uses of V0 to the corresponding def of V0.
68*0fca6ea1SDimitry Andric   DenseMap<const MachineInstr *, const MachineInstr *> V0Defs;
69*0fca6ea1SDimitry Andric };
70*0fca6ea1SDimitry Andric 
71*0fca6ea1SDimitry Andric } // namespace
72*0fca6ea1SDimitry Andric 
73*0fca6ea1SDimitry Andric char RISCVVectorPeephole::ID = 0;
74*0fca6ea1SDimitry Andric 
75*0fca6ea1SDimitry Andric INITIALIZE_PASS(RISCVVectorPeephole, DEBUG_TYPE, "RISC-V Fold Masks", false,
76*0fca6ea1SDimitry Andric                 false)
77*0fca6ea1SDimitry Andric 
78*0fca6ea1SDimitry Andric // If an AVL is a VLENB that's possibly scaled to be equal to VLMAX, convert it
79*0fca6ea1SDimitry Andric // to the VLMAX sentinel value.
80*0fca6ea1SDimitry Andric bool RISCVVectorPeephole::convertToVLMAX(MachineInstr &MI) const {
81*0fca6ea1SDimitry Andric   if (!RISCVII::hasVLOp(MI.getDesc().TSFlags) ||
82*0fca6ea1SDimitry Andric       !RISCVII::hasSEWOp(MI.getDesc().TSFlags))
83*0fca6ea1SDimitry Andric     return false;
84*0fca6ea1SDimitry Andric   MachineOperand &VL = MI.getOperand(RISCVII::getVLOpNum(MI.getDesc()));
85*0fca6ea1SDimitry Andric   if (!VL.isReg())
86*0fca6ea1SDimitry Andric     return false;
87*0fca6ea1SDimitry Andric   MachineInstr *Def = MRI->getVRegDef(VL.getReg());
88*0fca6ea1SDimitry Andric   if (!Def)
89*0fca6ea1SDimitry Andric     return false;
90*0fca6ea1SDimitry Andric 
91*0fca6ea1SDimitry Andric   // Fixed-point value, denominator=8
92*0fca6ea1SDimitry Andric   uint64_t ScaleFixed = 8;
93*0fca6ea1SDimitry Andric   // Check if the VLENB was potentially scaled with slli/srli
94*0fca6ea1SDimitry Andric   if (Def->getOpcode() == RISCV::SLLI) {
95*0fca6ea1SDimitry Andric     assert(Def->getOperand(2).getImm() < 64);
96*0fca6ea1SDimitry Andric     ScaleFixed <<= Def->getOperand(2).getImm();
97*0fca6ea1SDimitry Andric     Def = MRI->getVRegDef(Def->getOperand(1).getReg());
98*0fca6ea1SDimitry Andric   } else if (Def->getOpcode() == RISCV::SRLI) {
99*0fca6ea1SDimitry Andric     assert(Def->getOperand(2).getImm() < 64);
100*0fca6ea1SDimitry Andric     ScaleFixed >>= Def->getOperand(2).getImm();
101*0fca6ea1SDimitry Andric     Def = MRI->getVRegDef(Def->getOperand(1).getReg());
102*0fca6ea1SDimitry Andric   }
103*0fca6ea1SDimitry Andric 
104*0fca6ea1SDimitry Andric   if (!Def || Def->getOpcode() != RISCV::PseudoReadVLENB)
105*0fca6ea1SDimitry Andric     return false;
106*0fca6ea1SDimitry Andric 
107*0fca6ea1SDimitry Andric   auto LMUL = RISCVVType::decodeVLMUL(RISCVII::getLMul(MI.getDesc().TSFlags));
108*0fca6ea1SDimitry Andric   // Fixed-point value, denominator=8
109*0fca6ea1SDimitry Andric   unsigned LMULFixed = LMUL.second ? (8 / LMUL.first) : 8 * LMUL.first;
110*0fca6ea1SDimitry Andric   unsigned Log2SEW = MI.getOperand(RISCVII::getSEWOpNum(MI.getDesc())).getImm();
111*0fca6ea1SDimitry Andric   // A Log2SEW of 0 is an operation on mask registers only
112*0fca6ea1SDimitry Andric   unsigned SEW = Log2SEW ? 1 << Log2SEW : 8;
113*0fca6ea1SDimitry Andric   assert(RISCVVType::isValidSEW(SEW) && "Unexpected SEW");
114*0fca6ea1SDimitry Andric   assert(8 * LMULFixed / SEW > 0);
115*0fca6ea1SDimitry Andric 
116*0fca6ea1SDimitry Andric   // AVL = (VLENB * Scale)
117*0fca6ea1SDimitry Andric   //
118*0fca6ea1SDimitry Andric   // VLMAX = (VLENB * 8 * LMUL) / SEW
119*0fca6ea1SDimitry Andric   //
120*0fca6ea1SDimitry Andric   // AVL == VLMAX
121*0fca6ea1SDimitry Andric   // -> VLENB * Scale == (VLENB * 8 * LMUL) / SEW
122*0fca6ea1SDimitry Andric   // -> Scale == (8 * LMUL) / SEW
123*0fca6ea1SDimitry Andric   if (ScaleFixed != 8 * LMULFixed / SEW)
124*0fca6ea1SDimitry Andric     return false;
125*0fca6ea1SDimitry Andric 
126*0fca6ea1SDimitry Andric   VL.ChangeToImmediate(RISCV::VLMaxSentinel);
127*0fca6ea1SDimitry Andric 
128*0fca6ea1SDimitry Andric   return true;
129*0fca6ea1SDimitry Andric }
130*0fca6ea1SDimitry Andric 
131*0fca6ea1SDimitry Andric bool RISCVVectorPeephole::isAllOnesMask(const MachineInstr *MaskDef) const {
132*0fca6ea1SDimitry Andric   assert(MaskDef && MaskDef->isCopy() &&
133*0fca6ea1SDimitry Andric          MaskDef->getOperand(0).getReg() == RISCV::V0);
134*0fca6ea1SDimitry Andric   Register SrcReg = TRI->lookThruCopyLike(MaskDef->getOperand(1).getReg(), MRI);
135*0fca6ea1SDimitry Andric   if (!SrcReg.isVirtual())
136*0fca6ea1SDimitry Andric     return false;
137*0fca6ea1SDimitry Andric   MaskDef = MRI->getVRegDef(SrcReg);
138*0fca6ea1SDimitry Andric   if (!MaskDef)
139*0fca6ea1SDimitry Andric     return false;
140*0fca6ea1SDimitry Andric 
141*0fca6ea1SDimitry Andric   // TODO: Check that the VMSET is the expected bitwidth? The pseudo has
142*0fca6ea1SDimitry Andric   // undefined behaviour if it's the wrong bitwidth, so we could choose to
143*0fca6ea1SDimitry Andric   // assume that it's all-ones? Same applies to its VL.
144*0fca6ea1SDimitry Andric   switch (MaskDef->getOpcode()) {
145*0fca6ea1SDimitry Andric   case RISCV::PseudoVMSET_M_B1:
146*0fca6ea1SDimitry Andric   case RISCV::PseudoVMSET_M_B2:
147*0fca6ea1SDimitry Andric   case RISCV::PseudoVMSET_M_B4:
148*0fca6ea1SDimitry Andric   case RISCV::PseudoVMSET_M_B8:
149*0fca6ea1SDimitry Andric   case RISCV::PseudoVMSET_M_B16:
150*0fca6ea1SDimitry Andric   case RISCV::PseudoVMSET_M_B32:
151*0fca6ea1SDimitry Andric   case RISCV::PseudoVMSET_M_B64:
152*0fca6ea1SDimitry Andric     return true;
153*0fca6ea1SDimitry Andric   default:
154*0fca6ea1SDimitry Andric     return false;
155*0fca6ea1SDimitry Andric   }
156*0fca6ea1SDimitry Andric }
157*0fca6ea1SDimitry Andric 
158*0fca6ea1SDimitry Andric // Transform (VMERGE_VVM_<LMUL> false, false, true, allones, vl, sew) to
159*0fca6ea1SDimitry Andric // (VMV_V_V_<LMUL> false, true, vl, sew). It may decrease uses of VMSET.
160*0fca6ea1SDimitry Andric bool RISCVVectorPeephole::convertVMergeToVMv(MachineInstr &MI) const {
161*0fca6ea1SDimitry Andric #define CASE_VMERGE_TO_VMV(lmul)                                               \
162*0fca6ea1SDimitry Andric   case RISCV::PseudoVMERGE_VVM_##lmul:                                         \
163*0fca6ea1SDimitry Andric     NewOpc = RISCV::PseudoVMV_V_V_##lmul;                                      \
164*0fca6ea1SDimitry Andric     break;
165*0fca6ea1SDimitry Andric   unsigned NewOpc;
166*0fca6ea1SDimitry Andric   switch (MI.getOpcode()) {
167*0fca6ea1SDimitry Andric   default:
168*0fca6ea1SDimitry Andric     return false;
169*0fca6ea1SDimitry Andric     CASE_VMERGE_TO_VMV(MF8)
170*0fca6ea1SDimitry Andric     CASE_VMERGE_TO_VMV(MF4)
171*0fca6ea1SDimitry Andric     CASE_VMERGE_TO_VMV(MF2)
172*0fca6ea1SDimitry Andric     CASE_VMERGE_TO_VMV(M1)
173*0fca6ea1SDimitry Andric     CASE_VMERGE_TO_VMV(M2)
174*0fca6ea1SDimitry Andric     CASE_VMERGE_TO_VMV(M4)
175*0fca6ea1SDimitry Andric     CASE_VMERGE_TO_VMV(M8)
176*0fca6ea1SDimitry Andric   }
177*0fca6ea1SDimitry Andric 
178*0fca6ea1SDimitry Andric   Register MergeReg = MI.getOperand(1).getReg();
179*0fca6ea1SDimitry Andric   Register FalseReg = MI.getOperand(2).getReg();
180*0fca6ea1SDimitry Andric   // Check merge == false (or merge == undef)
181*0fca6ea1SDimitry Andric   if (MergeReg != RISCV::NoRegister && TRI->lookThruCopyLike(MergeReg, MRI) !=
182*0fca6ea1SDimitry Andric                                            TRI->lookThruCopyLike(FalseReg, MRI))
183*0fca6ea1SDimitry Andric     return false;
184*0fca6ea1SDimitry Andric 
185*0fca6ea1SDimitry Andric   assert(MI.getOperand(4).isReg() && MI.getOperand(4).getReg() == RISCV::V0);
186*0fca6ea1SDimitry Andric   if (!isAllOnesMask(V0Defs.lookup(&MI)))
187*0fca6ea1SDimitry Andric     return false;
188*0fca6ea1SDimitry Andric 
189*0fca6ea1SDimitry Andric   MI.setDesc(TII->get(NewOpc));
190*0fca6ea1SDimitry Andric   MI.removeOperand(1);  // Merge operand
191*0fca6ea1SDimitry Andric   MI.tieOperands(0, 1); // Tie false to dest
192*0fca6ea1SDimitry Andric   MI.removeOperand(3);  // Mask operand
193*0fca6ea1SDimitry Andric   MI.addOperand(
194*0fca6ea1SDimitry Andric       MachineOperand::CreateImm(RISCVII::TAIL_UNDISTURBED_MASK_UNDISTURBED));
195*0fca6ea1SDimitry Andric 
196*0fca6ea1SDimitry Andric   // vmv.v.v doesn't have a mask operand, so we may be able to inflate the
197*0fca6ea1SDimitry Andric   // register class for the destination and merge operands e.g. VRNoV0 -> VR
198*0fca6ea1SDimitry Andric   MRI->recomputeRegClass(MI.getOperand(0).getReg());
199*0fca6ea1SDimitry Andric   MRI->recomputeRegClass(MI.getOperand(1).getReg());
200*0fca6ea1SDimitry Andric   return true;
201*0fca6ea1SDimitry Andric }
202*0fca6ea1SDimitry Andric 
203*0fca6ea1SDimitry Andric bool RISCVVectorPeephole::convertToUnmasked(MachineInstr &MI) const {
204*0fca6ea1SDimitry Andric   const RISCV::RISCVMaskedPseudoInfo *I =
205*0fca6ea1SDimitry Andric       RISCV::getMaskedPseudoInfo(MI.getOpcode());
206*0fca6ea1SDimitry Andric   if (!I)
207*0fca6ea1SDimitry Andric     return false;
208*0fca6ea1SDimitry Andric 
209*0fca6ea1SDimitry Andric   if (!isAllOnesMask(V0Defs.lookup(&MI)))
210*0fca6ea1SDimitry Andric     return false;
211*0fca6ea1SDimitry Andric 
212*0fca6ea1SDimitry Andric   // There are two classes of pseudos in the table - compares and
213*0fca6ea1SDimitry Andric   // everything else.  See the comment on RISCVMaskedPseudo for details.
214*0fca6ea1SDimitry Andric   const unsigned Opc = I->UnmaskedPseudo;
215*0fca6ea1SDimitry Andric   const MCInstrDesc &MCID = TII->get(Opc);
216*0fca6ea1SDimitry Andric   [[maybe_unused]] const bool HasPolicyOp =
217*0fca6ea1SDimitry Andric       RISCVII::hasVecPolicyOp(MCID.TSFlags);
218*0fca6ea1SDimitry Andric   const bool HasPassthru = RISCVII::isFirstDefTiedToFirstUse(MCID);
219*0fca6ea1SDimitry Andric #ifndef NDEBUG
220*0fca6ea1SDimitry Andric   const MCInstrDesc &MaskedMCID = TII->get(MI.getOpcode());
221*0fca6ea1SDimitry Andric   assert(RISCVII::hasVecPolicyOp(MaskedMCID.TSFlags) ==
222*0fca6ea1SDimitry Andric              RISCVII::hasVecPolicyOp(MCID.TSFlags) &&
223*0fca6ea1SDimitry Andric          "Masked and unmasked pseudos are inconsistent");
224*0fca6ea1SDimitry Andric   assert(HasPolicyOp == HasPassthru && "Unexpected pseudo structure");
225*0fca6ea1SDimitry Andric #endif
226*0fca6ea1SDimitry Andric   (void)HasPolicyOp;
227*0fca6ea1SDimitry Andric 
228*0fca6ea1SDimitry Andric   MI.setDesc(MCID);
229*0fca6ea1SDimitry Andric 
230*0fca6ea1SDimitry Andric   // TODO: Increment all MaskOpIdxs in tablegen by num of explicit defs?
231*0fca6ea1SDimitry Andric   unsigned MaskOpIdx = I->MaskOpIdx + MI.getNumExplicitDefs();
232*0fca6ea1SDimitry Andric   MI.removeOperand(MaskOpIdx);
233*0fca6ea1SDimitry Andric 
234*0fca6ea1SDimitry Andric   // The unmasked pseudo will no longer be constrained to the vrnov0 reg class,
235*0fca6ea1SDimitry Andric   // so try and relax it to vr.
236*0fca6ea1SDimitry Andric   MRI->recomputeRegClass(MI.getOperand(0).getReg());
237*0fca6ea1SDimitry Andric   unsigned PassthruOpIdx = MI.getNumExplicitDefs();
238*0fca6ea1SDimitry Andric   if (HasPassthru) {
239*0fca6ea1SDimitry Andric     if (MI.getOperand(PassthruOpIdx).getReg() != RISCV::NoRegister)
240*0fca6ea1SDimitry Andric       MRI->recomputeRegClass(MI.getOperand(PassthruOpIdx).getReg());
241*0fca6ea1SDimitry Andric   } else
242*0fca6ea1SDimitry Andric     MI.removeOperand(PassthruOpIdx);
243*0fca6ea1SDimitry Andric 
244*0fca6ea1SDimitry Andric   return true;
245*0fca6ea1SDimitry Andric }
246*0fca6ea1SDimitry Andric 
247*0fca6ea1SDimitry Andric bool RISCVVectorPeephole::runOnMachineFunction(MachineFunction &MF) {
248*0fca6ea1SDimitry Andric   if (skipFunction(MF.getFunction()))
249*0fca6ea1SDimitry Andric     return false;
250*0fca6ea1SDimitry Andric 
251*0fca6ea1SDimitry Andric   // Skip if the vector extension is not enabled.
252*0fca6ea1SDimitry Andric   const RISCVSubtarget &ST = MF.getSubtarget<RISCVSubtarget>();
253*0fca6ea1SDimitry Andric   if (!ST.hasVInstructions())
254*0fca6ea1SDimitry Andric     return false;
255*0fca6ea1SDimitry Andric 
256*0fca6ea1SDimitry Andric   TII = ST.getInstrInfo();
257*0fca6ea1SDimitry Andric   MRI = &MF.getRegInfo();
258*0fca6ea1SDimitry Andric   TRI = MRI->getTargetRegisterInfo();
259*0fca6ea1SDimitry Andric 
260*0fca6ea1SDimitry Andric   bool Changed = false;
261*0fca6ea1SDimitry Andric 
262*0fca6ea1SDimitry Andric   // Masked pseudos coming out of isel will have their mask operand in the form:
263*0fca6ea1SDimitry Andric   //
264*0fca6ea1SDimitry Andric   // $v0:vr = COPY %mask:vr
265*0fca6ea1SDimitry Andric   // %x:vr = Pseudo_MASK %a:vr, %b:br, $v0:vr
266*0fca6ea1SDimitry Andric   //
267*0fca6ea1SDimitry Andric   // Because $v0 isn't in SSA, keep track of its definition at each use so we
268*0fca6ea1SDimitry Andric   // can check mask operands.
269*0fca6ea1SDimitry Andric   for (const MachineBasicBlock &MBB : MF) {
270*0fca6ea1SDimitry Andric     const MachineInstr *CurrentV0Def = nullptr;
271*0fca6ea1SDimitry Andric     for (const MachineInstr &MI : MBB) {
272*0fca6ea1SDimitry Andric       if (MI.readsRegister(RISCV::V0, TRI))
273*0fca6ea1SDimitry Andric         V0Defs[&MI] = CurrentV0Def;
274*0fca6ea1SDimitry Andric 
275*0fca6ea1SDimitry Andric       if (MI.definesRegister(RISCV::V0, TRI))
276*0fca6ea1SDimitry Andric         CurrentV0Def = &MI;
277*0fca6ea1SDimitry Andric     }
278*0fca6ea1SDimitry Andric   }
279*0fca6ea1SDimitry Andric 
280*0fca6ea1SDimitry Andric   for (MachineBasicBlock &MBB : MF) {
281*0fca6ea1SDimitry Andric     for (MachineInstr &MI : MBB) {
282*0fca6ea1SDimitry Andric       Changed |= convertToVLMAX(MI);
283*0fca6ea1SDimitry Andric       Changed |= convertToUnmasked(MI);
284*0fca6ea1SDimitry Andric       Changed |= convertVMergeToVMv(MI);
285*0fca6ea1SDimitry Andric     }
286*0fca6ea1SDimitry Andric   }
287*0fca6ea1SDimitry Andric 
288*0fca6ea1SDimitry Andric   return Changed;
289*0fca6ea1SDimitry Andric }
290*0fca6ea1SDimitry Andric 
291*0fca6ea1SDimitry Andric FunctionPass *llvm::createRISCVVectorPeepholePass() {
292*0fca6ea1SDimitry Andric   return new RISCVVectorPeephole();
293*0fca6ea1SDimitry Andric }
294