xref: /llvm-project/llvm/lib/Target/RISCV/RISCVVectorMaskDAGMutation.cpp (revision 01a15dca09e56dce850ab6fb3ecddfb3f8c6c172)
1 //===- RISCVVectorMaskDAGMutation.cpp - RISC-V Vector Mask DAGMutation ----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
// A schedule mutation that adds an artificial dependency between mask producer
// instructions and masked instructions, so that we can reduce the live range
11 // overlaps of mask registers.
12 //
13 // The reason why we need to do this:
14 // 1. When tracking register pressure, we don't track physical registers.
// 2. We have a RegisterClass for the mask register (`VMV0`), but we don't use
//    it in most RVV pseudos (it is only used for inline asm constraints and
//    add/sub-with-carry instructions). Instead, we use the physical register V0
//    directly and insert a `$v0 = COPY ...` before each use. Additionally, the
//    register allocator has a fundamental issue handling a RegisterClass with
//    only one physical register, so we can't simply replace V0 with VMV0.
21 // 3. For mask producers, we are using VR RegisterClass (we can allocate V0-V31
22 //    to it). So if V0 is not available, there are still 31 available registers
23 //    out there.
24 //
25 // This means that the RegPressureTracker can't track the pressure of mask
26 // registers correctly.
27 //
28 // This schedule mutation is a workaround to fix this issue.
29 //
30 //===----------------------------------------------------------------------===//
31 
32 #include "MCTargetDesc/RISCVBaseInfo.h"
33 #include "MCTargetDesc/RISCVMCTargetDesc.h"
34 #include "RISCVRegisterInfo.h"
35 #include "RISCVTargetMachine.h"
36 #include "llvm/CodeGen/LiveIntervals.h"
37 #include "llvm/CodeGen/MachineInstr.h"
38 #include "llvm/CodeGen/ScheduleDAGInstrs.h"
39 #include "llvm/CodeGen/ScheduleDAGMutation.h"
40 #include "llvm/TargetParser/RISCVTargetParser.h"
41 
42 #define DEBUG_TYPE "machine-scheduler"
43 
44 namespace llvm {
45 
46 static inline bool isVectorMaskProducer(const MachineInstr *MI) {
47   switch (RISCV::getRVVMCOpcode(MI->getOpcode())) {
48   // Vector Mask Instructions
49   case RISCV::VMAND_MM:
50   case RISCV::VMNAND_MM:
51   case RISCV::VMANDN_MM:
52   case RISCV::VMXOR_MM:
53   case RISCV::VMOR_MM:
54   case RISCV::VMNOR_MM:
55   case RISCV::VMORN_MM:
56   case RISCV::VMXNOR_MM:
57   case RISCV::VMSBF_M:
58   case RISCV::VMSIF_M:
59   case RISCV::VMSOF_M:
60   // Vector Integer Add-with-Carry / Subtract-with-Borrow Instructions
61   case RISCV::VMADC_VV:
62   case RISCV::VMADC_VX:
63   case RISCV::VMADC_VI:
64   case RISCV::VMADC_VVM:
65   case RISCV::VMADC_VXM:
66   case RISCV::VMADC_VIM:
67   case RISCV::VMSBC_VV:
68   case RISCV::VMSBC_VX:
69   case RISCV::VMSBC_VVM:
70   case RISCV::VMSBC_VXM:
71   // Vector Integer Compare Instructions
72   case RISCV::VMSEQ_VV:
73   case RISCV::VMSEQ_VX:
74   case RISCV::VMSEQ_VI:
75   case RISCV::VMSNE_VV:
76   case RISCV::VMSNE_VX:
77   case RISCV::VMSNE_VI:
78   case RISCV::VMSLT_VV:
79   case RISCV::VMSLT_VX:
80   case RISCV::VMSLTU_VV:
81   case RISCV::VMSLTU_VX:
82   case RISCV::VMSLE_VV:
83   case RISCV::VMSLE_VX:
84   case RISCV::VMSLE_VI:
85   case RISCV::VMSLEU_VV:
86   case RISCV::VMSLEU_VX:
87   case RISCV::VMSLEU_VI:
88   case RISCV::VMSGTU_VX:
89   case RISCV::VMSGTU_VI:
90   case RISCV::VMSGT_VX:
91   case RISCV::VMSGT_VI:
92   // Vector Floating-Point Compare Instructions
93   case RISCV::VMFEQ_VV:
94   case RISCV::VMFEQ_VF:
95   case RISCV::VMFNE_VV:
96   case RISCV::VMFNE_VF:
97   case RISCV::VMFLT_VV:
98   case RISCV::VMFLT_VF:
99   case RISCV::VMFLE_VV:
100   case RISCV::VMFLE_VF:
101   case RISCV::VMFGT_VF:
102   case RISCV::VMFGE_VF:
103     return true;
104   }
105   return false;
106 }
107 
108 class RISCVVectorMaskDAGMutation : public ScheduleDAGMutation {
109 private:
110   const TargetRegisterInfo *TRI;
111 
112 public:
113   RISCVVectorMaskDAGMutation(const TargetRegisterInfo *TRI) : TRI(TRI) {}
114 
115   void apply(ScheduleDAGInstrs *DAG) override {
116     SUnit *NearestUseV0SU = nullptr;
117     for (SUnit &SU : DAG->SUnits) {
118       const MachineInstr *MI = SU.getInstr();
119       if (MI->findRegisterUseOperand(RISCV::V0, TRI))
120         NearestUseV0SU = &SU;
121 
122       if (NearestUseV0SU && NearestUseV0SU != &SU && isVectorMaskProducer(MI) &&
123           // For LMUL=8 cases, there will be more possibilities to spill.
124           // FIXME: We should use RegPressureTracker to do fine-grained
125           // controls.
126           RISCVII::getLMul(MI->getDesc().TSFlags) != RISCVII::LMUL_8)
127         DAG->addEdge(&SU, SDep(NearestUseV0SU, SDep::Artificial));
128     }
129   }
130 };
131 
132 std::unique_ptr<ScheduleDAGMutation>
133 createRISCVVectorMaskDAGMutation(const TargetRegisterInfo *TRI) {
134   return std::make_unique<RISCVVectorMaskDAGMutation>(TRI);
135 }
136 
137 } // namespace llvm
138